From 7f20caf208639c1ff352c3b765b048ae281c92a6 Mon Sep 17 00:00:00 2001 From: Agus Zubiaga Date: Wed, 23 Jul 2025 20:11:38 -0300 Subject: [PATCH] Moar progress --- Cargo.lock | 1 + crates/acp_thread/src/acp_thread.rs | 178 +++++----- crates/acp_thread/src/connection.rs | 110 ++++-- crates/agent_servers/Cargo.toml | 1 + crates/agent_servers/src/agent_servers.rs | 6 +- crates/agent_servers/src/claude.rs | 333 +++++++++--------- crates/agent_servers/src/claude/tools.rs | 190 +++++----- crates/agent_servers/src/mcp_server.rs | 7 +- .../agent_servers/src/stdio_agent_server.rs | 93 ++--- crates/agent_ui/src/acp/thread_view.rs | 40 ++- 10 files changed, 506 insertions(+), 453 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 0758ecbdf49e7f381e3b4e5f17ae4792893bc595..fc922bc60eeb61beafc36f85404177b6d69adb68 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -150,6 +150,7 @@ name = "agent_servers" version = "0.1.0" dependencies = [ "acp_thread", + "agent-client-protocol", "agentic-coding-protocol", "anyhow", "collections", diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs index 87f7c23c5bcde70a808e2aa9e7e6f1804d653da5..15855c8a996ff8244ebaead3a241d6c57c559847 100644 --- a/crates/acp_thread/src/acp_thread.rs +++ b/crates/acp_thread/src/acp_thread.rs @@ -258,7 +258,7 @@ impl Display for ToolCallStatus { } #[derive(Debug, PartialEq, Clone)] -enum ContentBlock { +pub enum ContentBlock { Empty, Markdown { markdown: Entity }, } @@ -599,8 +599,10 @@ impl Error for LoadError {} impl AcpThread { pub fn new( - connection: impl AgentConnection + 'static, + connection: Arc, + // todo! remove me title: SharedString, + // todo! remove this? child_status: Option>>, project: Entity, session_id: acp::SessionId, @@ -616,7 +618,7 @@ impl AcpThread { title, project, send_task: None, - connection: Arc::new(connection), + connection, child_status, session_id, } @@ -712,37 +714,47 @@ impl AcpThread { pub fn update_tool_call( &mut self, - tool_call: acp::ToolCall, + id: acp::ToolCallId, + status: acp::ToolCallStatus, + content: Option>, cx: &mut Context, ) -> Result<()> { + let languages = self.project.read(cx).languages().clone(); + let (ix, current_call) = self.tool_call_mut(&id).context("Tool call not found")?; + + if let Some(content) = content { + current_call.content = content + .into_iter() + .map(|chunk| ToolCallContent::from_acp(chunk, languages.clone(), cx)) + .collect(); + } + current_call.status = ToolCallStatus::Allowed { status }; + + cx.emit(AcpThreadEvent::EntryUpdated(ix)); + + Ok(()) + } + + /// Updates a tool call if id matches an existing entry, otherwise inserts a new one. + pub fn upsert_tool_call(&mut self, tool_call: acp::ToolCall, cx: &mut Context) { let status = ToolCallStatus::Allowed { status: tool_call.status, }; - self.update_tool_call_inner(tool_call, status, cx) + self.upsert_tool_call_inner(tool_call, status, cx) } - pub fn update_tool_call_inner( + pub fn upsert_tool_call_inner( &mut self, tool_call: acp::ToolCall, status: ToolCallStatus, cx: &mut Context, - ) -> Result<()> { + ) { let language_registry = self.project.read(cx).languages().clone(); let call = ToolCall::from_acp(tool_call, status, language_registry, cx); let location = call.locations.last().cloned(); if let Some((ix, current_call)) = self.tool_call_mut(&call.id) { - match ¤t_call.status { - ToolCallStatus::WaitingForConfirmation { .. } => { - anyhow::bail!("Tool call hasn't been authorized yet") - } - ToolCallStatus::Rejected => { - anyhow::bail!("Tool call was rejected and therefore can't be updated") - } - ToolCallStatus::Allowed { .. } | ToolCallStatus::Canceled => {} - } - *current_call = call; cx.emit(AcpThreadEvent::EntryUpdated(ix)); @@ -753,25 +765,6 @@ impl AcpThread { if let Some(location) = location { self.set_project_location(location, cx) } - - Ok(()) - } - - fn tool_call(&mut self, id: &acp::ToolCallId) -> Option<(usize, &ToolCall)> { - // todo! use map - self.entries - .iter() - .enumerate() - .rev() - .find_map(|(index, tool_call)| { - if let AgentThreadEntry::ToolCall(tool_call) = tool_call - && &tool_call.id == id - { - Some((index, tool_call)) - } else { - None - } - }) } fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> { @@ -804,7 +797,7 @@ impl AcpThread { respond_tx: tx, }; - self.update_tool_call_inner(tool_call, status, cx); + self.upsert_tool_call_inner(tool_call, status, cx); rx } @@ -913,8 +906,8 @@ impl AcpThread { false } - pub fn authenticate(&self) -> impl use<> + Future> { - self.connection.authenticate() + pub fn authenticate(&self, cx: &mut App) -> impl use<> + Future> { + self.connection.authenticate(cx) } #[cfg(any(test, feature = "test-support"))] @@ -948,18 +941,23 @@ impl AcpThread { ); let (tx, rx) = oneshot::channel(); - let cancel = self.cancel(cx); + self.cancel(cx); + let old_send = self.send_task.take(); self.send_task = Some(cx.spawn(async move |this, cx| { async { - cancel.await.log_err(); - + if let Some(old_send) = old_send { + old_send.await; + } let result = this - .update(cx, |this, _| { - this.connection.prompt(acp::PromptToolArguments { - prompt: message, - session_id: this.session_id.clone(), - }) + .update(cx, |this, cx| { + this.connection.prompt( + acp::PromptToolArguments { + prompt: message, + session_id: this.session_id.clone(), + }, + cx, + ) })? .await; tx.send(result).log_err(); @@ -979,32 +977,25 @@ impl AcpThread { .boxed() } - pub fn cancel(&mut self, cx: &mut Context) -> Task> { - if self.send_task.take().is_some() { - let request = self.connection.cancel(); - cx.spawn(async move |this, cx| { - request.await?; - this.update(cx, |this, _cx| { - for entry in this.entries.iter_mut() { - if let AgentThreadEntry::ToolCall(call) = entry { - let cancel = matches!( - call.status, - ToolCallStatus::WaitingForConfirmation { .. } - | ToolCallStatus::Allowed { - status: acp::ToolCallStatus::InProgress - } - ); - - if cancel { - call.status = ToolCallStatus::Canceled; - } + pub fn cancel(&mut self, cx: &mut Context) { + if self.send_task.take().is_none() { + return; + } + self.connection.cancel(cx); + for entry in self.entries.iter_mut() { + if let AgentThreadEntry::ToolCall(call) = entry { + let cancel = matches!( + call.status, + ToolCallStatus::WaitingForConfirmation { .. } + | ToolCallStatus::Allowed { + status: acp::ToolCallStatus::InProgress } - } - })?; - Ok(()) - }) - } else { - Task::ready(Ok(())) + ); + + if cancel { + call.status = ToolCallStatus::Canceled; + } + } } } @@ -1160,14 +1151,14 @@ impl AcpThread { #[derive(Clone)] pub struct OldAcpClientDelegate { - thread: WeakEntity, + thread: Rc>>, cx: AsyncApp, next_tool_call_id: Rc>, // sent_buffer_versions: HashMap, HashMap>, } impl OldAcpClientDelegate { - pub fn new(thread: WeakEntity, cx: AsyncApp) -> Self { + pub fn new(thread: Rc>>, cx: AsyncApp) -> Self { Self { thread, cx, @@ -1179,6 +1170,7 @@ impl OldAcpClientDelegate { let cx = &mut self.cx.clone(); cx.update(|cx| { self.thread + .borrow() .update(cx, |thread, cx| thread.clear_completed_plan_entries(cx)) })? .context("Failed to update thread")?; @@ -1193,7 +1185,7 @@ impl OldAcpClientDelegate { let content = self .cx .update(|cx| { - self.thread.update(cx, |thread, cx| { + self.thread.borrow().update(cx, |thread, cx| { thread.read_text_file( acp::ReadTextFileArguments { path: request.path, @@ -1219,6 +1211,7 @@ impl acp_old::Client for OldAcpClientDelegate { cx.update(|cx| { self.thread + .borrow() .update(cx, |thread, cx| match params.chunk { acp_old::AssistantMessageChunk::Text { text } => { thread.push_assistant_chunk(text.into(), false, cx) @@ -1313,7 +1306,7 @@ impl acp_old::Client for OldAcpClientDelegate { let response = cx .update(|cx| { - self.thread.update(cx, |thread, cx| { + self.thread.borrow().update(cx, |thread, cx| { thread.request_tool_call_permission(tool_call, acp_options, cx) }) })? @@ -1341,14 +1334,14 @@ impl acp_old::Client for OldAcpClientDelegate { self.next_tool_call_id.replace(old_acp_id); cx.update(|cx| { - self.thread.update(cx, |thread, cx| { - thread.update_tool_call( + self.thread.borrow().update(cx, |thread, cx| { + thread.upsert_tool_call( into_new_tool_call(acp::ToolCallId(old_acp_id.to_string().into()), request), cx, ) }) })? - .context("Failed to update thread")??; + .context("Failed to update thread")?; Ok(acp_old::PushToolCallResponse { id: acp_old::ToolCallId(old_acp_id), @@ -1362,7 +1355,7 @@ impl acp_old::Client for OldAcpClientDelegate { let cx = &mut self.cx.clone(); cx.update(|cx| { - self.thread.update(cx, |thread, cx| { + self.thread.borrow().update(cx, |thread, cx| { let languages = thread.project.read(cx).languages().clone(); if let Some((ix, tool_call)) = thread @@ -1399,7 +1392,7 @@ impl acp_old::Client for OldAcpClientDelegate { let cx = &mut self.cx.clone(); cx.update(|cx| { - self.thread.update(cx, |thread, cx| { + self.thread.borrow().update(cx, |thread, cx| { thread.update_plan( acp::Plan { entries: request @@ -1424,7 +1417,7 @@ impl acp_old::Client for OldAcpClientDelegate { let content = self .cx .update(|cx| { - self.thread.update(cx, |thread, cx| { + self.thread.borrow().update(cx, |thread, cx| { thread.read_text_file( acp::ReadTextFileArguments { path: request.path, @@ -1447,7 +1440,7 @@ impl acp_old::Client for OldAcpClientDelegate { ) -> Result<(), acp_old::Error> { self.cx .update(|cx| { - self.thread.update(cx, |thread, cx| { + self.thread.borrow().update(cx, |thread, cx| { thread.write_text_file( acp::WriteTextFileToolArguments { path: request.path, @@ -1782,10 +1775,7 @@ mod tests { cx.run_until_parked(); - thread - .update(cx, |thread, cx| thread.cancel(cx)) - .await - .unwrap(); + thread.update(cx, |thread, cx| thread.cancel(cx)); thread.read_with(cx, |thread, _| { assert!(matches!( @@ -1861,8 +1851,10 @@ mod tests { let thread = cx.new(|cx| { let foreground_executor = cx.foreground_executor().clone(); + let thread_rc = Rc::new(RefCell::new(cx.entity().downgrade())); + let (connection, io_fut) = acp_old::AgentConnection::connect_to_agent( - OldAcpClientDelegate::new(cx.entity().downgrade(), cx.to_async()), + OldAcpClientDelegate::new(thread_rc.clone(), cx.to_async()), stdin_tx, stdout_rx, move |fut| { @@ -1876,10 +1868,16 @@ mod tests { Ok(()) } }); - AcpThread::new( + let connection = OldAcpAgentConnection { connection, + child_status: io_task, + thread: thread_rc, + }; + + AcpThread::new( + Arc::new(connection), "Test".into(), - Some(io_task), + None, project, acp::SessionId("test".into()), cx, diff --git a/crates/acp_thread/src/connection.rs b/crates/acp_thread/src/connection.rs index e3088b3f2e6cc4a3dd72140b0d65dd36aa2421b6..c2f21a59bf58f448d96d76b839c953c48dd20606 100644 --- a/crates/acp_thread/src/connection.rs +++ b/crates/acp_thread/src/connection.rs @@ -1,55 +1,91 @@ +use std::{cell::RefCell, error::Error, fmt, path::Path, rc::Rc, sync::Arc}; + use agent_client_protocol as acp; use agentic_coding_protocol::{self as acp_old, AgentRequest}; use anyhow::Result; -use futures::future::{FutureExt as _, LocalBoxFuture}; +use gpui::{AppContext, Entity, Task, WeakEntity}; +use project::Project; +use ui::App; + +use crate::AcpThread; pub trait AgentConnection { - fn new_session( + fn new_thread( &self, - params: acp::NewSessionToolArguments, - ) -> LocalBoxFuture<'static, Result>; + project: Entity, + cwd: &Path, + connection: Arc, + cx: &mut App, + ) -> Task>>; - fn authenticate(&self) -> LocalBoxFuture<'static, Result<()>>; + fn authenticate(&self, cx: &mut App) -> Task>; - fn prompt(&self, params: acp::PromptToolArguments) -> LocalBoxFuture<'static, Result<()>>; + fn prompt(&self, params: acp::PromptToolArguments, cx: &mut App) -> Task>; - fn cancel(&self) -> LocalBoxFuture<'static, Result<()>>; + fn cancel(&self, cx: &mut App); } -impl AgentConnection for acp_old::AgentConnection { - fn new_session( +#[derive(Debug)] +pub struct Unauthenticated; + +impl Error for Unauthenticated {} +impl fmt::Display for Unauthenticated { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Unauthenticated") + } +} + +pub struct OldAcpAgentConnection { + pub connection: acp_old::AgentConnection, + pub child_status: Task>, + pub thread: Rc>>, +} + +impl AgentConnection for OldAcpAgentConnection { + fn new_thread( &self, - _params: acp::NewSessionToolArguments, - ) -> LocalBoxFuture<'static, Result> { - let task = self.request_any( + project: Entity, + _cwd: &Path, + connection: Arc, + cx: &mut App, + ) -> Task>> { + let task = self.connection.request_any( acp_old::InitializeParams { protocol_version: acp_old::ProtocolVersion::latest(), } .into_any(), ); - async move { + let current_thread = self.thread.clone(); + cx.spawn(async move |cx| { let result = task.await?; let result = acp_old::InitializeParams::response_from_any(result)?; if !result.is_authenticated { - anyhow::bail!("Not authenticated"); + anyhow::bail!(Unauthenticated) } - Ok(acp::SessionId("acp-old-no-id".into())) - } - .boxed_local() + cx.update(|cx| { + let thread = cx.new(|cx| { + let session_id = acp::SessionId("acp-old-no-id".into()); + AcpThread::new(connection, "Gemini".into(), None, project, session_id, cx) + }); + current_thread.replace(thread.downgrade()); + thread + }) + }) } - fn authenticate(&self) -> LocalBoxFuture<'static, Result<()>> { - let task = self.request_any(acp_old::AuthenticateParams.into_any()); - async move { + fn authenticate(&self, cx: &mut App) -> Task> { + let task = self + .connection + .request_any(acp_old::AuthenticateParams.into_any()); + cx.foreground_executor().spawn(async move { task.await?; - anyhow::Ok(()) - } - .boxed_local() + Ok(()) + }) } - fn prompt(&self, params: acp::PromptToolArguments) -> LocalBoxFuture<'static, Result<()>> { + fn prompt(&self, params: acp::PromptToolArguments, cx: &mut App) -> Task> { let chunks = params .prompt .into_iter() @@ -64,20 +100,24 @@ impl AgentConnection for acp_old::AgentConnection { }) .collect(); - let task = self.request_any(acp_old::SendUserMessageParams { chunks }.into_any()); - async move { + let task = self + .connection + .request_any(acp_old::SendUserMessageParams { chunks }.into_any()); + cx.foreground_executor().spawn(async move { task.await?; anyhow::Ok(()) - } - .boxed_local() + }) } - fn cancel(&self) -> LocalBoxFuture<'static, Result<()>> { - let task = self.request_any(acp_old::CancelSendMessageParams.into_any()); - async move { - task.await?; - anyhow::Ok(()) - } - .boxed_local() + fn cancel(&self, cx: &mut App) { + let task = self + .connection + .request_any(acp_old::CancelSendMessageParams.into_any()); + cx.foreground_executor() + .spawn(async move { + task.await?; + anyhow::Ok(()) + }) + .detach_and_log_err(cx) } } diff --git a/crates/agent_servers/Cargo.toml b/crates/agent_servers/Cargo.toml index 0552677a27e61462b14d733bce93004d291fc723..75d682012bebc0576997c889a2065d1a0150be7f 100644 --- a/crates/agent_servers/Cargo.toml +++ b/crates/agent_servers/Cargo.toml @@ -18,6 +18,7 @@ doctest = false [dependencies] acp_thread.workspace = true +agent-client-protocol.workspace = true agentic-coding-protocol.workspace = true anyhow.workspace = true collections.workspace = true diff --git a/crates/agent_servers/src/agent_servers.rs b/crates/agent_servers/src/agent_servers.rs index 48536e13ac28d0c16e59d7eb021d5ab39a850cf4..1e29074e6554f1e6f07a084ea166a7250784a277 100644 --- a/crates/agent_servers/src/agent_servers.rs +++ b/crates/agent_servers/src/agent_servers.rs @@ -14,7 +14,7 @@ pub use gemini::*; pub use settings::*; pub use stdio_agent_server::*; -use acp_thread::AcpThread; +use acp_thread::AgentConnection; use anyhow::Result; use collections::HashMap; use gpui::{App, AsyncApp, Entity, SharedString, Task}; @@ -38,12 +38,12 @@ pub trait AgentServer: Send { fn empty_state_message(&self) -> &'static str; fn supports_always_allow(&self) -> bool; - fn new_thread( + fn connect( &self, root_dir: &Path, project: &Entity, cx: &mut App, - ) -> Task>>; + ) -> Task>>; } impl std::fmt::Debug for AgentServerCommand { diff --git a/crates/agent_servers/src/claude.rs b/crates/agent_servers/src/claude.rs index 86321bcb982c4322b471b6fcce324757d0c74a47..79b02ddf74b02d957f7e2c706c46c6fa8e4656ed 100644 --- a/crates/agent_servers/src/claude.rs +++ b/crates/agent_servers/src/claude.rs @@ -9,27 +9,27 @@ use std::fmt::Display; use std::path::Path; use std::pin::pin; use std::rc::Rc; +use std::sync::Arc; use uuid::Uuid; -use agentic_coding_protocol as acp_old; +use agent_client_protocol as acp; use anyhow::{Result, anyhow}; use futures::channel::oneshot; -use futures::future::LocalBoxFuture; -use futures::{AsyncBufReadExt, AsyncWriteExt, SinkExt}; +use futures::{AsyncBufReadExt, AsyncWriteExt}; use futures::{ AsyncRead, AsyncWrite, FutureExt, StreamExt, channel::mpsc::{self, UnboundedReceiver, UnboundedSender}, io::BufReader, select_biased, }; -use gpui::{App, AppContext, Entity, Task}; +use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity}; use serde::{Deserialize, Serialize}; use util::ResultExt; use crate::claude::tools::ClaudeTool; use crate::mcp_server::{self, McpConfig, ZedMcpServer}; use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings}; -use acp_thread::{AcpThread, AgentConnection, OldAcpClientDelegate}; +use acp_thread::{AcpThread, AgentConnection}; #[derive(Clone)] pub struct ClaudeCode; @@ -55,21 +55,20 @@ impl AgentServer for ClaudeCode { false } - fn new_thread( + fn connect( &self, root_dir: &Path, project: &Entity, cx: &mut App, - ) -> Task>> { + ) -> Task>> { let project = project.clone(); let root_dir = root_dir.to_path_buf(); - let title = self.name().into(); cx.spawn(async move |cx| { - let (mut delegate_tx, delegate_rx) = watch::channel(None); + let mut threads_map = Rc::new(RefCell::new(HashMap::default())); let tool_id_map = Rc::new(RefCell::new(HashMap::default())); let permission_mcp_server = - ZedMcpServer::new(delegate_rx, tool_id_map.clone(), cx).await?; + ZedMcpServer::new(threads_map, tool_id_map.clone(), cx).await?; let mut mcp_servers = HashMap::default(); mcp_servers.insert( @@ -101,7 +100,7 @@ impl AgentServer for ClaudeCode { let (outgoing_tx, outgoing_rx) = mpsc::unbounded(); let (cancel_tx, mut cancel_rx) = mpsc::unbounded::>>(); - let session_id = Uuid::new_v4(); + let session_id = acp::SessionId(Uuid::new_v4().to_string().into()); log::trace!("Starting session with id: {}", session_id); @@ -152,41 +151,33 @@ impl AgentServer for ClaudeCode { }) .detach(); - cx.new(|cx| { - let end_turn_tx = Rc::new(RefCell::new(None)); - let delegate = OldAcpClientDelegate::new(cx.entity().downgrade(), cx.to_async()); - delegate_tx.send(Some(delegate.clone())).log_err(); - - let handler_task = cx.foreground_executor().spawn({ - let end_turn_tx = end_turn_tx.clone(); - let tool_id_map = tool_id_map.clone(); - let delegate = delegate.clone(); - async move { - while let Some(message) = incoming_message_rx.next().await { - ClaudeAgentConnection::handle_message( - delegate.clone(), - message, - end_turn_tx.clone(), - tool_id_map.clone(), - ) - .await - } + let end_turn_tx = Rc::new(RefCell::new(None)); + let handler_task = cx.spawn({ + let end_turn_tx = end_turn_tx.clone(); + async move |cx| { + while let Some(message) = incoming_message_rx.next().await { + ClaudeAgentConnection::handle_message( + threads_map.clone(), + message, + end_turn_tx.clone(), + cx, + ) + .await } - }); - - let mut connection = ClaudeAgentConnection { - delegate, - outgoing_tx, - end_turn_tx, - cancel_tx, - session_id, - _handler_task: handler_task, - _mcp_server: None, - }; + } + }); + + let connection = ClaudeAgentConnection { + threads_map, + outgoing_tx, + end_turn_tx, + cancel_tx, + session_id, + _handler_task: handler_task, + _mcp_server: Some(permission_mcp_server), + }; - connection._mcp_server = Some(permission_mcp_server); - acp_thread::AcpThread::new(connection, title, None, project.clone(), cx) - }) + Ok(Arc::new(connection) as _) }) } } @@ -205,71 +196,84 @@ fn send_interrupt(_pid: i32) -> anyhow::Result<()> { } impl AgentConnection for ClaudeAgentConnection { - /// Send a request to the agent and wait for a response. - fn request_any( + fn new_thread( &self, - params: acp_old::AnyAgentRequest, - ) -> LocalBoxFuture<'static, Result> { - let delegate = self.delegate.clone(); - let end_turn_tx = self.end_turn_tx.clone(); - let outgoing_tx = self.outgoing_tx.clone(); - let mut cancel_tx = self.cancel_tx.clone(); - let session_id = self.session_id; - async move { - match params { - // todo: consider sending an empty request so we get the init response? - acp_old::AnyAgentRequest::InitializeParams(_) => Ok( - acp_old::AnyAgentResult::InitializeResponse(acp::InitializeResponse { - is_authenticated: true, - protocol_version: acp_old::ProtocolVersion::latest(), - }), - ), - acp_old::AnyAgentRequest::AuthenticateParams(_) => { - Err(anyhow!("Authentication not supported")) + project: Entity, + _cwd: &Path, + connection: Arc, + cx: &mut App, + ) -> Task>> { + let session_id = self.session_id.clone(); + let thread = + cx.new(|cx| AcpThread::new(connection, "Claude".into(), None, project, session_id, cx)); + Task::ready(Ok(thread)) + } + + fn authenticate(&self, _cx: &mut App) -> Task> { + Task::ready(Err(anyhow!("Authentication not supported"))) + } + + fn prompt(&self, params: acp::PromptToolArguments, cx: &mut App) -> Task> { + let Some(thread) = self + .threads_map + .borrow() + .get(¶ms.session_id) + .and_then(|entity| entity.upgrade()) + else { + return Task::ready(Err(anyhow!("Thread not found"))); + }; + + thread.update(cx, |thread, cx| { + thread.clear_completed_plan_entries(cx); + }); + + let (tx, rx) = oneshot::channel(); + self.end_turn_tx.borrow_mut().replace(tx); + + let mut content = String::new(); + for chunk in params.prompt { + match chunk { + acp::ContentBlock::Text(text_content) => { + content.push_str(&text_content.text); } - acp_old::AnyAgentRequest::SendUserMessageParams(message) => { - delegate.clear_completed_plan_entries().await?; - - let (tx, rx) = oneshot::channel(); - end_turn_tx.borrow_mut().replace(tx); - let mut content = String::new(); - for chunk in message.chunks { - match chunk { - acp_old::UserMessageChunk::Text { text } => content.push_str(&text), - acp_old::UserMessageChunk::Path { path } => { - content.push_str(&format!("@{path:?}")) - } - } - } - outgoing_tx.unbounded_send(SdkMessage::User { - message: Message { - role: Role::User, - content: Content::UntaggedText(content), - id: None, - model: None, - stop_reason: None, - stop_sequence: None, - usage: None, - }, - session_id: Some(session_id), - })?; - rx.await??; - Ok(acp_old::AnyAgentResult::SendUserMessageResponse( - acp::SendUserMessageResponse, - )) + acp::ContentBlock::ResourceLink(resource_link) => { + content.push_str(&format!("@{}", resource_link.uri)); } - acp_old::AnyAgentRequest::CancelSendMessageParams(_) => { - let (done_tx, done_rx) = oneshot::channel(); - cancel_tx.send(done_tx).await?; - done_rx.await??; - - Ok(acp_old::AnyAgentResult::CancelSendMessageResponse( - acp::CancelSendMessageResponse, - )) + acp::ContentBlock::Audio(_) + | acp::ContentBlock::Image(_) + | acp::ContentBlock::Resource(_) => { + // TODO } } } - .boxed_local() + + if let Err(err) = self.outgoing_tx.unbounded_send(SdkMessage::User { + message: Message { + role: Role::User, + content: Content::UntaggedText(content), + id: None, + model: None, + stop_reason: None, + stop_sequence: None, + usage: None, + }, + session_id: Some(params.session_id.to_string()), + }) { + return Task::ready(Err(anyhow!(err))); + } + + cx.foreground_executor().spawn(async move { + rx.await??; + Ok(()) + }) + } + + fn cancel(&self, cx: &mut App) { + let (done_tx, done_rx) = oneshot::channel(); + self.cancel_tx.unbounded_send(done_tx); + cx.foreground_executor() + .spawn(async move { done_rx.await? }) + .detach_and_log_err(cx); } } @@ -282,7 +286,7 @@ enum ClaudeSessionMode { async fn spawn_claude( command: &AgentServerCommand, mode: ClaudeSessionMode, - session_id: Uuid, + session_id: acp::SessionId, mcp_config_path: &Path, root_dir: &Path, ) -> Result { @@ -323,8 +327,8 @@ async fn spawn_claude( } struct ClaudeAgentConnection { - delegate: OldAcpClientDelegate, - session_id: Uuid, + threads_map: Rc>>>, + session_id: acp::SessionId, outgoing_tx: UnboundedSender, end_turn_tx: Rc>>>>, cancel_tx: UnboundedSender>>, @@ -334,80 +338,91 @@ struct ClaudeAgentConnection { impl ClaudeAgentConnection { async fn handle_message( - delegate: OldAcpClientDelegate, + threads_map: Rc>>>, message: SdkMessage, end_turn_tx: Rc>>>>, - tool_id_map: Rc>>, + cx: &mut AsyncApp, ) { match message { - SdkMessage::Assistant { message, .. } | SdkMessage::User { message, .. } => { + SdkMessage::Assistant { + message, + session_id, + } + | SdkMessage::User { + message, + session_id, + } => { + let threads_map = threads_map.borrow(); + let Some(thread) = session_id + .and_then(|session_id| threads_map.get(&acp::SessionId(session_id.into()))) + .and_then(|entity| entity.upgrade()) + else { + log::error!("Thread not found for session"); + return; + }; for chunk in message.content.chunks() { match chunk { ContentChunk::Text { text } | ContentChunk::UntaggedText(text) => { - delegate - .stream_assistant_message_chunk( - acp_old::StreamAssistantMessageChunkParams { - chunk: acp::AssistantMessageChunk::Text { text }, - }, - ) - .await + thread + .update(cx, |thread, cx| { + thread.push_assistant_chunk(text.into(), false, cx) + }) .log_err(); } ContentChunk::ToolUse { id, name, input } => { let claude_tool = ClaudeTool::infer(&name, input); - if let ClaudeTool::TodoWrite(Some(params)) = claude_tool { - delegate - .update_plan(acp::UpdatePlanParams { - entries: params.todos.into_iter().map(Into::into).collect(), - }) - .await - .log_err(); - } else if let Some(resp) = delegate - .push_tool_call(claude_tool.as_acp()) - .await - .log_err() - { - tool_id_map.borrow_mut().insert(id, resp.id); - } + thread + .update(cx, |thread, cx| { + if let ClaudeTool::TodoWrite(Some(params)) = claude_tool { + thread.update_plan( + acp::Plan { + entries: params + .todos + .into_iter() + .map(Into::into) + .collect(), + }, + cx, + ) + } else { + thread.upsert_tool_call( + claude_tool.as_acp(acp::ToolCallId(id.into())), + cx, + ); + } + }) + .log_err(); } ContentChunk::ToolResult { content, tool_use_id, } => { - let id = tool_id_map.borrow_mut().remove(&tool_use_id); - if let Some(id) = id { - let content = content.to_string(); - delegate - .update_tool_call(acp_old::UpdateToolCallParams { - tool_call_id: id, - status: acp::ToolCallStatus::Finished, - // Don't unset existing content - content: (!content.is_empty()).then_some( - acp_old::ToolCallContent::Markdown { - // For now we only include text content - markdown: content, - }, - ), - }) - .await - .log_err(); - } + let content = content.to_string(); + thread + .update(cx, |thread, cx| { + thread.update_tool_call( + acp::ToolCallId(tool_use_id.into()), + acp::ToolCallStatus::Completed, + (!content.is_empty()).then(|| vec![content.into()]), + cx, + ) + }) + .log_err(); } ContentChunk::Image | ContentChunk::Document | ContentChunk::Thinking | ContentChunk::RedactedThinking | ContentChunk::WebSearchToolResult => { - delegate - .stream_assistant_message_chunk( - acp_old::StreamAssistantMessageChunkParams { - chunk: acp::AssistantMessageChunk::Text { - text: format!("Unsupported content: {:?}", chunk), - }, - }, - ) - .await + thread + .update(cx, |thread, cx| { + thread.push_assistant_chunk( + format!("Unsupported content: {:?}", chunk).into(), + false, + cx, + ) + }) .log_err(); } } @@ -591,14 +606,14 @@ enum SdkMessage { Assistant { message: Message, // from Anthropic SDK #[serde(skip_serializing_if = "Option::is_none")] - session_id: Option, + session_id: Option, }, // A user message User { message: Message, // from Anthropic SDK #[serde(skip_serializing_if = "Option::is_none")] - session_id: Option, + session_id: Option, }, // Emitted as the last message in a conversation diff --git a/crates/agent_servers/src/claude/tools.rs b/crates/agent_servers/src/claude/tools.rs index b84c78e2ccc5b1c1c16c04dd7c6c5257bb71c8fe..36b4661579cd878ee75bef0043944b7c08eb59ad 100644 --- a/crates/agent_servers/src/claude/tools.rs +++ b/crates/agent_servers/src/claude/tools.rs @@ -1,5 +1,6 @@ use std::path::PathBuf; +use agent_client_protocol as acp; use agentic_coding_protocol as acp_old; use itertools::Itertools; use schemars::JsonSchema; @@ -115,51 +116,36 @@ impl ClaudeTool { Self::Other { name, .. } => name.clone(), } } - - pub fn content(&self) -> Option { + pub fn content(&self) -> Vec { match &self { - Self::Other { input, .. } => Some(acp_old::ToolCallContent::Markdown { - markdown: format!( + Self::Other { input, .. } => vec![ + format!( "```json\n{}```", serde_json::to_string_pretty(&input).unwrap_or("{}".to_string()) - ), - }), - Self::Task(Some(params)) => Some(acp_old::ToolCallContent::Markdown { - markdown: params.prompt.clone(), - }), - Self::NotebookRead(Some(params)) => Some(acp_old::ToolCallContent::Markdown { - markdown: params.notebook_path.display().to_string(), - }), - Self::NotebookEdit(Some(params)) => Some(acp_old::ToolCallContent::Markdown { - markdown: params.new_source.clone(), - }), - Self::Terminal(Some(params)) => Some(acp_old::ToolCallContent::Markdown { - markdown: format!( + ) + .into(), + ], + Self::Task(Some(params)) => vec![params.prompt.clone().into()], + Self::NotebookRead(Some(params)) => { + vec![params.notebook_path.display().to_string().into()] + } + Self::NotebookEdit(Some(params)) => vec![params.new_source.clone().into()], + Self::Terminal(Some(params)) => vec![ + format!( "`{}`\n\n{}", params.command, params.description.as_deref().unwrap_or_default() - ), - }), - Self::ReadFile(Some(params)) => Some(acp_old::ToolCallContent::Markdown { - markdown: params.abs_path.display().to_string(), - }), - Self::Ls(Some(params)) => Some(acp_old::ToolCallContent::Markdown { - markdown: params.path.display().to_string(), - }), - Self::Glob(Some(params)) => Some(acp_old::ToolCallContent::Markdown { - markdown: params.to_string(), - }), - Self::Grep(Some(params)) => Some(acp_old::ToolCallContent::Markdown { - markdown: format!("`{params}`"), - }), - Self::WebFetch(Some(params)) => Some(acp_old::ToolCallContent::Markdown { - markdown: params.prompt.clone(), - }), - Self::WebSearch(Some(params)) => Some(acp_old::ToolCallContent::Markdown { - markdown: params.to_string(), - }), - Self::TodoWrite(Some(params)) => Some(acp_old::ToolCallContent::Markdown { - markdown: params + ) + .into(), + ], + Self::ReadFile(Some(params)) => vec![params.abs_path.display().to_string().into()], + Self::Ls(Some(params)) => vec![params.path.display().to_string().into()], + Self::Glob(Some(params)) => vec![params.to_string().into()], + Self::Grep(Some(params)) => vec![format!("`{params}`").into()], + Self::WebFetch(Some(params)) => vec![params.prompt.clone().into()], + Self::WebSearch(Some(params)) => vec![params.to_string().into()], + Self::TodoWrite(Some(params)) => vec![ + params .todos .iter() .map(|todo| { @@ -174,37 +160,39 @@ impl ClaudeTool { todo.content ) }) - .join("\n"), - }), - Self::ExitPlanMode(Some(params)) => Some(acp_old::ToolCallContent::Markdown { - markdown: params.plan.clone(), - }), - Self::Edit(Some(params)) => Some(acp_old::ToolCallContent::Diff { - diff: acp_old::Diff { + .join("\n") + .into(), + ], + Self::ExitPlanMode(Some(params)) => vec![params.plan.clone().into()], + Self::Edit(Some(params)) => vec![acp::ToolCallContent::Diff { + diff: acp::Diff { path: params.abs_path.clone(), old_text: Some(params.old_text.clone()), new_text: params.new_text.clone(), }, - }), - Self::Write(Some(params)) => Some(acp_old::ToolCallContent::Diff { - diff: acp_old::Diff { + }], + Self::Write(Some(params)) => vec![acp::ToolCallContent::Diff { + diff: acp::Diff { path: params.file_path.clone(), old_text: None, new_text: params.content.clone(), }, - }), + }], Self::MultiEdit(Some(params)) => { // todo: show multiple edits in a multibuffer? params .edits .first() - .map(|edit| acp_old::ToolCallContent::Diff { - diff: acp_old::Diff { - path: params.file_path.clone(), - old_text: Some(edit.old_string.clone()), - new_text: edit.new_string.clone(), - }, + .map(|edit| { + vec![acp::ToolCallContent::Diff { + diff: acp::Diff { + path: params.file_path.clone(), + old_text: Some(edit.old_string.clone()), + new_text: edit.new_string.clone(), + }, + }] }) + .unwrap_or_default() } Self::Task(None) | Self::NotebookRead(None) @@ -220,28 +208,28 @@ impl ClaudeTool { | Self::ExitPlanMode(None) | Self::Edit(None) | Self::Write(None) - | Self::MultiEdit(None) => None, + | Self::MultiEdit(None) => vec![], } } - pub fn icon(&self) -> acp_old::Icon { + pub fn kind(&self) -> acp::ToolKind { match self { - Self::Task(_) => acp_old::Icon::Hammer, - Self::NotebookRead(_) => acp_old::Icon::FileSearch, - Self::NotebookEdit(_) => acp_old::Icon::Pencil, - Self::Edit(_) => acp_old::Icon::Pencil, - Self::MultiEdit(_) => acp_old::Icon::Pencil, - Self::Write(_) => acp_old::Icon::Pencil, - Self::ReadFile(_) => acp_old::Icon::FileSearch, - Self::Ls(_) => acp_old::Icon::Folder, - Self::Glob(_) => acp_old::Icon::FileSearch, - Self::Grep(_) => acp_old::Icon::Regex, - Self::Terminal(_) => acp_old::Icon::Terminal, - Self::WebSearch(_) => acp_old::Icon::Globe, - Self::WebFetch(_) => acp_old::Icon::Globe, - Self::TodoWrite(_) => acp_old::Icon::LightBulb, - Self::ExitPlanMode(_) => acp_old::Icon::Hammer, - Self::Other { .. } => acp_old::Icon::Hammer, + Self::Task(_) => acp::ToolKind::Think, + Self::NotebookRead(_) => acp::ToolKind::Read, + Self::NotebookEdit(_) => acp::ToolKind::Edit, + Self::Edit(_) => acp::ToolKind::Edit, + Self::MultiEdit(_) => acp::ToolKind::Edit, + Self::Write(_) => acp::ToolKind::Edit, + Self::ReadFile(_) => acp::ToolKind::Read, + Self::Ls(_) => acp::ToolKind::Search, + Self::Glob(_) => acp::ToolKind::Search, + Self::Grep(_) => acp::ToolKind::Search, + Self::Terminal(_) => acp::ToolKind::Execute, + Self::WebSearch(_) => acp::ToolKind::Search, + Self::WebFetch(_) => acp::ToolKind::Fetch, + Self::TodoWrite(_) => acp::ToolKind::Think, + Self::ExitPlanMode(_) => acp::ToolKind::Think, + Self::Other { .. } => acp::ToolKind::Other, } } @@ -348,55 +336,55 @@ impl ClaudeTool { } } - pub fn locations(&self) -> Vec { + pub fn locations(&self) -> Vec { match &self { - Self::Edit(Some(EditToolParams { abs_path, .. })) => vec![acp_old::ToolCallLocation { + Self::Edit(Some(EditToolParams { abs_path, .. })) => vec![acp::ToolCallLocation { path: abs_path.clone(), line: None, }], Self::MultiEdit(Some(MultiEditToolParams { file_path, .. })) => { - vec![acp_old::ToolCallLocation { + vec![acp::ToolCallLocation { path: file_path.clone(), line: None, }] } Self::Write(Some(WriteToolParams { file_path, .. })) => { - vec![acp_old::ToolCallLocation { + vec![acp::ToolCallLocation { path: file_path.clone(), line: None, }] } Self::ReadFile(Some(ReadToolParams { abs_path, offset, .. - })) => vec![acp_old::ToolCallLocation { + })) => vec![acp::ToolCallLocation { path: abs_path.clone(), line: *offset, }], Self::NotebookRead(Some(NotebookReadToolParams { notebook_path, .. })) => { - vec![acp_old::ToolCallLocation { + vec![acp::ToolCallLocation { path: notebook_path.clone(), line: None, }] } Self::NotebookEdit(Some(NotebookEditToolParams { notebook_path, .. })) => { - vec![acp_old::ToolCallLocation { + vec![acp::ToolCallLocation { path: notebook_path.clone(), line: None, }] } Self::Glob(Some(GlobToolParams { path: Some(path), .. - })) => vec![acp_old::ToolCallLocation { + })) => vec![acp::ToolCallLocation { path: path.clone(), line: None, }], - Self::Ls(Some(LsToolParams { path, .. })) => vec![acp_old::ToolCallLocation { + Self::Ls(Some(LsToolParams { path, .. })) => vec![acp::ToolCallLocation { path: path.clone(), line: None, }], Self::Grep(Some(GrepToolParams { path: Some(path), .. - })) => vec![ToolCallLocation { + })) => vec![acp::ToolCallLocation { path: PathBuf::from(path), line: None, }], @@ -419,11 +407,13 @@ impl ClaudeTool { } } - pub fn as_acp(&self) -> acp_old::PushToolCallParams { - acp_old::PushToolCallParams { + pub fn as_acp(&self, id: acp::ToolCallId) -> acp::ToolCall { + acp::ToolCall { + id, + kind: self.kind(), + status: acp::ToolCallStatus::InProgress, label: self.label(), content: self.content(), - icon: self.icon(), locations: self.locations(), } } @@ -609,12 +599,12 @@ pub enum TodoPriority { Low, } -impl Into for TodoPriority { - fn into(self) -> acp_old::PlanEntryPriority { +impl Into for TodoPriority { + fn into(self) -> acp::PlanEntryPriority { match self { - TodoPriority::High => acp_old::PlanEntryPriority::High, - TodoPriority::Medium => acp_old::PlanEntryPriority::Medium, - TodoPriority::Low => acp_old::PlanEntryPriority::Low, + TodoPriority::High => acp::PlanEntryPriority::High, + TodoPriority::Medium => acp::PlanEntryPriority::Medium, + TodoPriority::Low => acp::PlanEntryPriority::Low, } } } @@ -627,12 +617,12 @@ pub enum TodoStatus { Completed, } -impl Into for TodoStatus { - fn into(self) -> acp_old::PlanEntryStatus { +impl Into for TodoStatus { + fn into(self) -> acp::PlanEntryStatus { match self { - TodoStatus::Pending => acp_old::PlanEntryStatus::Pending, - TodoStatus::InProgress => acp_old::PlanEntryStatus::InProgress, - TodoStatus::Completed => acp_old::PlanEntryStatus::Completed, + TodoStatus::Pending => acp::PlanEntryStatus::Pending, + TodoStatus::InProgress => acp::PlanEntryStatus::InProgress, + TodoStatus::Completed => acp::PlanEntryStatus::Completed, } } } @@ -649,9 +639,9 @@ pub struct Todo { pub status: TodoStatus, } -impl Into for Todo { - fn into(self) -> acp_old::PlanEntry { - acp_old::PlanEntry { +impl Into for Todo { + fn into(self) -> acp::PlanEntry { + acp::PlanEntry { content: self.content, priority: self.priority.into(), status: self.status.into(), diff --git a/crates/agent_servers/src/mcp_server.rs b/crates/agent_servers/src/mcp_server.rs index 628f56df85b8b8d684191b9f6b74d438c8a3eca4..0fb94c46bcc3b8ed2bb63620089deba6689d56bd 100644 --- a/crates/agent_servers/src/mcp_server.rs +++ b/crates/agent_servers/src/mcp_server.rs @@ -1,6 +1,7 @@ use std::{cell::RefCell, path::PathBuf, rc::Rc}; -use acp_thread::OldAcpClientDelegate; +use acp_thread::{AcpThread, OldAcpClientDelegate}; +use agent_client_protocol::{self as acp}; use agentic_coding_protocol::{self as acp_old, Client as _}; use anyhow::{Context, Result}; use collections::HashMap; @@ -52,7 +53,7 @@ enum PermissionToolBehavior { impl ZedMcpServer { pub async fn new( - delegate: watch::Receiver>, + thread_map: Rc>>>, tool_id_map: Rc>>, cx: &AsyncApp, ) -> Result { @@ -60,7 +61,7 @@ impl ZedMcpServer { mcp_server.handle_request::(Self::handle_initialize); mcp_server.handle_request::(Self::handle_list_tools); mcp_server.handle_request::(move |request, cx| { - Self::handle_call_tool(request, delegate.clone(), tool_id_map.clone(), cx) + Self::handle_call_tool(request, thread_map.clone(), tool_id_map.clone(), cx) }); Ok(Self { server: mcp_server }) diff --git a/crates/agent_servers/src/stdio_agent_server.rs b/crates/agent_servers/src/stdio_agent_server.rs index 6c6f52519d0136481748cbe6baa8df76ed013497..13d6f4f6924861c370900473d11c03411aff84ca 100644 --- a/crates/agent_servers/src/stdio_agent_server.rs +++ b/crates/agent_servers/src/stdio_agent_server.rs @@ -1,10 +1,10 @@ use crate::{AgentServer, AgentServerCommand, AgentServerVersion}; -use acp_thread::{AcpThread, LoadError, OldAcpClientDelegate}; +use acp_thread::{AgentConnection, LoadError, OldAcpAgentConnection, OldAcpClientDelegate}; use agentic_coding_protocol as acp_old; use anyhow::{Result, anyhow}; -use gpui::{App, AsyncApp, Entity, Task, prelude::*}; +use gpui::{App, AsyncApp, Entity, Task, WeakEntity, prelude::*}; use project::Project; -use std::path::Path; +use std::{cell::RefCell, path::Path, rc::Rc, sync::Arc}; use util::ResultExt; pub trait StdioAgentServer: Send + Clone { @@ -47,16 +47,15 @@ impl AgentServer for T { self.supports_always_allow() } - fn new_thread( + fn connect( &self, root_dir: &Path, project: &Entity, cx: &mut App, - ) -> Task>> { + ) -> Task>> { let root_dir = root_dir.to_path_buf(); let project = project.clone(); let this = self.clone(); - let title = self.name().into(); cx.spawn(async move |cx| { let command = this.command(&project, cx).await?; @@ -73,47 +72,53 @@ impl AgentServer for T { let stdin = child.stdin.take().unwrap(); let stdout = child.stdout.take().unwrap(); - cx.new(|cx| { - let foreground_executor = cx.foreground_executor().clone(); - - let (connection, io_fut) = acp_old::AgentConnection::connect_to_agent( - OldAcpClientDelegate::new(cx.entity().downgrade(), cx.to_async()), - stdin, - stdout, - move |fut| foreground_executor.spawn(fut).detach(), - ); - - let io_task = cx.background_spawn(async move { - io_fut.await.log_err(); - }); - - let child_status = cx.background_spawn(async move { - let result = match child.status().await { - Err(e) => Err(anyhow!(e)), - Ok(result) if result.success() => Ok(()), - Ok(result) => { - if let Some(AgentServerVersion::Unsupported { + let foreground_executor = cx.foreground_executor().clone(); + + let thread_rc = Rc::new(RefCell::new(WeakEntity::new_invalid())); + + let (connection, io_fut) = acp_old::AgentConnection::connect_to_agent( + OldAcpClientDelegate::new(thread_rc.clone(), cx.clone()), + stdin, + stdout, + move |fut| foreground_executor.spawn(fut).detach(), + ); + + let io_task = cx.background_spawn(async move { + io_fut.await.log_err(); + }); + + let child_status = cx.background_spawn(async move { + let result = match child.status().await { + Err(e) => Err(anyhow!(e)), + Ok(result) if result.success() => Ok(()), + Ok(result) => { + if let Some(AgentServerVersion::Unsupported { + error_message, + upgrade_message, + upgrade_command, + }) = this.version(&command).await.log_err() + { + Err(anyhow!(LoadError::Unsupported { error_message, upgrade_message, - upgrade_command, - }) = this.version(&command).await.log_err() - { - Err(anyhow!(LoadError::Unsupported { - error_message, - upgrade_message, - upgrade_command - })) - } else { - Err(anyhow!(LoadError::Exited(result.code().unwrap_or(-127)))) - } + upgrade_command + })) + } else { + Err(anyhow!(LoadError::Exited(result.code().unwrap_or(-127)))) } - }; - drop(io_task); - result - }); - - AcpThread::new(connection, title, Some(child_status), project.clone(), cx) - }) + } + }; + drop(io_task); + result + }); + + let connection: Arc = Arc::new(OldAcpAgentConnection { + connection, + child_status, + thread: thread_rc, + }); + + Ok(connection) }) } } diff --git a/crates/agent_ui/src/acp/thread_view.rs b/crates/agent_ui/src/acp/thread_view.rs index 0b4f3325fc1e6f4cda2428a7c1e4eabceb108d87..0dfb6fef28cd6ae24d9d13157ae27ab5cb913b8e 100644 --- a/crates/agent_ui/src/acp/thread_view.rs +++ b/crates/agent_ui/src/acp/thread_view.rs @@ -209,9 +209,9 @@ impl AcpThreadView { .map(|worktree| worktree.read(cx).abs_path()) .unwrap_or_else(|| paths::home_dir().as_path().into()); - let task = agent.new_thread(&root_dir, &project, cx); + let connect_task = agent.connect(&root_dir, &project, cx); let load_task = cx.spawn_in(window, async move |this, cx| { - let thread = match task.await { + let connection = match task.await { Ok(thread) => thread, Err(err) => { this.update(cx, |this, cx| { @@ -223,14 +223,10 @@ impl AcpThreadView { } }; - let init_response = async { - let resp = thread - .read_with(cx, |thread, _cx| thread.initialize())? - .await?; - anyhow::Ok(resp) - }; - - let result = match init_response.await { + let result = match connection + .new_thread(&project, root_dir, connection.clone(), cx) + .await + { Err(e) => { let mut cx = cx.clone(); if e.downcast_ref::().is_some() { @@ -246,25 +242,31 @@ impl AcpThreadView { } else { Err(e) } - } else { - Err(e) - } - } - Ok(response) => { - if !response.is_authenticated { + } else if e.downcast_ref::().is_some() { this.update(cx, |this, _| { this.thread_state = ThreadState::Unauthenticated { thread }; }) .ok(); return; - }; - Ok(()) + } else { + Err(e) + } } + Ok(session_id) => Ok(session_id), }; this.update_in(cx, |this, window, cx| { match result { - Ok(()) => { + Ok(session_id) => { + let thread = AcpThread::new( + connection, + agent.title(), + None, + project.clone(), + cx, + session_id, + ); + let thread_subscription = cx.subscribe_in(&thread, window, Self::handle_thread_event);