Stream deserialized thread to `AcpThread`

Antonio Scandurra created

Change summary

crates/agent2/src/agent.rs  | 40 ++++++++++++++++++++++++++--------
crates/agent2/src/thread.rs | 44 +++++++++++++++++++++++++++++---------
2 files changed, 63 insertions(+), 21 deletions(-)

Detailed changes

crates/agent2/src/agent.rs 🔗

@@ -5,7 +5,7 @@ use crate::{
     OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread, ThreadEvent, ToolCallAuthorization,
     UserMessageContent, WebSearchTool, templates::Templates,
 };
-use crate::{DbThread, ThreadsDatabase};
+use crate::{DbThread, ThreadId, ThreadsDatabase};
 use acp_thread::{AcpThread, AcpThreadMetadata, AgentModelSelector};
 use agent_client_protocol as acp;
 use agent_settings::AgentSettings;
@@ -473,10 +473,18 @@ impl NativeAgentConnection {
         };
         log::debug!("Found session for: {}", session_id);
 
-        let mut response_stream = match f(thread, cx) {
+        let response_stream = match f(thread, cx) {
             Ok(stream) => stream,
             Err(err) => return Task::ready(Err(err)),
         };
+        Self::handle_thread_events(response_stream, acp_thread, cx)
+    }
+
+    fn handle_thread_events(
+        mut response_stream: mpsc::UnboundedReceiver<Result<ThreadEvent>>,
+        acp_thread: WeakEntity<AcpThread>,
+        cx: &mut App,
+    ) -> Task<Result<acp::PromptResponse>> {
         cx.spawn(async move |cx| {
             // Handle response stream and forward to session.acp_thread
             while let Some(result) = response_stream.next().await {
@@ -486,7 +494,15 @@ impl NativeAgentConnection {
 
                         match event {
                             ThreadEvent::UserMessage(message) => {
-                                todo!()
+                                acp_thread.update(cx, |thread, cx| {
+                                    for content in message.content {
+                                        thread.push_user_content_block(
+                                            Some(message.id.clone()),
+                                            content.into(),
+                                            cx,
+                                        );
+                                    }
+                                })?;
                             }
                             ThreadEvent::AgentText(text) => {
                                 acp_thread.update(cx, |thread, cx| {
@@ -806,19 +822,19 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
         session_id: acp::SessionId,
         cx: &mut App,
     ) -> Task<Result<Entity<acp_thread::AcpThread>>> {
-        let thread_id = session_id.clone().into();
+        let thread_id = ThreadId::from(session_id.clone());
         let database = self.0.update(cx, |this, _| this.thread_database.clone());
         cx.spawn(async move |cx| {
             let database = database.await.map_err(|e| anyhow!(e))?;
             let db_thread = database
-                .load_thread(thread_id)
+                .load_thread(thread_id.clone())
                 .await?
                 .context("no such thread found")?;
 
             let acp_thread = cx.update(|cx| {
                 cx.new(|cx| {
                     acp_thread::AcpThread::new(
-                        db_thread.title,
+                        db_thread.title.clone(),
                         self.clone(),
                         project.clone(),
                         session_id.clone(),
@@ -835,6 +851,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
                     .update(cx, |registry, cx| {
                         db_thread
                             .model
+                            .as_ref()
                             .and_then(|model| {
                                 let model = SelectedModel {
                                     provider: model.provider.clone().into(),
@@ -852,7 +869,9 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
                     .context("no model by id")?;
 
                 let thread = cx.new(|cx| {
-                    let mut thread = Thread::new(
+                    let mut thread = Thread::from_db(
+                        thread_id,
+                        db_thread,
                         project.clone(),
                         agent.project_context.clone(),
                         agent.context_server_registry.clone(),
@@ -873,7 +892,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
                 agent.sessions.insert(
                     session_id,
                     Session {
-                        thread,
+                        thread: thread.clone(),
                         acp_thread: acp_thread.downgrade(),
                         _subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
                             this.sessions.remove(acp_thread.session_id());
@@ -882,8 +901,9 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
                 );
             })?;
 
-            // we need to actually deserialize the DbThread.
-            // todo!()
+            let events = thread.update(cx, |thread, cx| thread.replay(cx))?;
+            cx.update(|cx| Self::handle_thread_events(events, acp_thread.downgrade(), cx))?
+                .await?;
 
             Ok(acp_thread)
         })

crates/agent2/src/thread.rs 🔗

@@ -12,7 +12,7 @@ use futures::{
     channel::{mpsc, oneshot},
     stream::FuturesUnordered,
 };
-use gpui::{App, Context, Entity, SharedString, Task};
+use gpui::{App, AppContext, Context, Entity, SharedString, Task};
 use language_model::{
     LanguageModel, LanguageModelCompletionEvent, LanguageModelImage, LanguageModelProviderId,
     LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
@@ -545,7 +545,10 @@ impl Thread {
         }
     }
 
-    pub fn replay(&self, cx: &mut Context<Self>) -> mpsc::UnboundedReceiver<Result<ThreadEvent>> {
+    pub fn replay(
+        &mut self,
+        cx: &mut Context<Self>,
+    ) -> mpsc::UnboundedReceiver<Result<ThreadEvent>> {
         let (tx, rx) = mpsc::unbounded();
         let stream = ThreadEventStream(tx);
         for message in &self.messages {
@@ -615,16 +618,15 @@ impl Thread {
             );
             tool.replay(tool_use.input.clone(), output, tool_event_stream, cx)
                 .log_err();
-        } else {
-            stream.update_tool_call_fields(
-                &tool_use.id,
-                acp::ToolCallUpdateFields {
-                    content: Some(vec![TOOL_CANCELED_MESSAGE.into()]),
-                    status: Some(acp::ToolCallStatus::Failed),
-                    ..Default::default()
-                },
-            );
         }
+
+        stream.update_tool_call_fields(
+            &tool_use.id,
+            acp::ToolCallUpdateFields {
+                status: Some(acp::ToolCallStatus::Completed),
+                ..Default::default()
+            },
+        );
     }
 
     pub fn project(&self) -> &Entity<Project> {
@@ -1744,6 +1746,26 @@ impl From<acp::ContentBlock> for UserMessageContent {
     }
 }
 
+impl From<UserMessageContent> for acp::ContentBlock {
+    fn from(content: UserMessageContent) -> Self {
+        match content {
+            UserMessageContent::Text(text) => acp::ContentBlock::Text(acp::TextContent {
+                text,
+                annotations: None,
+            }),
+            UserMessageContent::Image(image) => acp::ContentBlock::Image(acp::ImageContent {
+                data: image.source.to_string(),
+                mime_type: "image/png".to_string(),
+                annotations: None,
+                uri: None,
+            }),
+            UserMessageContent::Mention { uri, content } => {
+                todo!()
+            }
+        }
+    }
+}
+
 fn convert_image(image_content: acp::ImageContent) -> LanguageModelImage {
     LanguageModelImage {
         source: image_content.data.into(),