Fix mentions roundtrip from/to database and other history bugs (#36575)

Antonio Scandurra created

Release Notes:

- N/A

Change summary

crates/agent2/src/agent.rs  | 170 ++++++++++++++++++++++++++++++++++++++
crates/agent2/src/thread.rs |  58 +++++++------
2 files changed, 200 insertions(+), 28 deletions(-)

Detailed changes

crates/agent2/src/agent.rs 🔗

@@ -577,6 +577,10 @@ impl NativeAgent {
     }
 
     fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
+        if thread.read(cx).is_empty() {
+            return;
+        }
+
         let database_future = ThreadsDatabase::connect(cx);
         let (id, db_thread) =
             thread.update(cx, |thread, cx| (thread.id().clone(), thread.to_db(cx)));
@@ -989,12 +993,19 @@ impl acp_thread::AgentSessionResume for NativeAgentSessionResume {
 
 #[cfg(test)]
 mod tests {
+    use crate::HistoryEntryId;
+
     use super::*;
-    use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelId, AgentModelInfo};
+    use acp_thread::{
+        AgentConnection, AgentModelGroupName, AgentModelId, AgentModelInfo, MentionUri,
+    };
     use fs::FakeFs;
     use gpui::TestAppContext;
+    use indoc::indoc;
+    use language_model::fake_provider::FakeLanguageModel;
     use serde_json::json;
     use settings::SettingsStore;
+    use util::path;
 
     #[gpui::test]
     async fn test_maintaining_project_context(cx: &mut TestAppContext) {
@@ -1179,6 +1190,163 @@ mod tests {
         );
     }
 
+    #[gpui::test]
+    #[cfg_attr(target_os = "windows", ignore)] // TODO: Fix this test on Windows
+    async fn test_save_load_thread(cx: &mut TestAppContext) {
+        init_test(cx);
+        let fs = FakeFs::new(cx.executor());
+        fs.insert_tree(
+            "/",
+            json!({
+                "a": {
+                    "b.md": "Lorem"
+                }
+            }),
+        )
+        .await;
+        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
+        let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
+        let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));
+        let agent = NativeAgent::new(
+            project.clone(),
+            history_store.clone(),
+            Templates::new(),
+            None,
+            fs.clone(),
+            &mut cx.to_async(),
+        )
+        .await
+        .unwrap();
+        let connection = Rc::new(NativeAgentConnection(agent.clone()));
+
+        let acp_thread = cx
+            .update(|cx| {
+                connection
+                    .clone()
+                    .new_thread(project.clone(), Path::new(""), cx)
+            })
+            .await
+            .unwrap();
+        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
+        let thread = agent.read_with(cx, |agent, _| {
+            agent.sessions.get(&session_id).unwrap().thread.clone()
+        });
+
+        // Ensure empty threads are not saved, even if they get mutated.
+        let model = Arc::new(FakeLanguageModel::default());
+        let summary_model = Arc::new(FakeLanguageModel::default());
+        thread.update(cx, |thread, cx| {
+            thread.set_model(model, cx);
+            thread.set_summarization_model(Some(summary_model), cx);
+        });
+        cx.run_until_parked();
+        assert_eq!(history_entries(&history_store, cx), vec![]);
+
+        let model = thread.read_with(cx, |thread, _| thread.model().unwrap().clone());
+        let model = model.as_fake();
+        let summary_model = thread.read_with(cx, |thread, _| {
+            thread.summarization_model().unwrap().clone()
+        });
+        let summary_model = summary_model.as_fake();
+        let send = acp_thread.update(cx, |thread, cx| {
+            thread.send(
+                vec![
+                    "What does ".into(),
+                    acp::ContentBlock::ResourceLink(acp::ResourceLink {
+                        name: "b.md".into(),
+                        uri: MentionUri::File {
+                            abs_path: path!("/a/b.md").into(),
+                        }
+                        .to_uri()
+                        .to_string(),
+                        annotations: None,
+                        description: None,
+                        mime_type: None,
+                        size: None,
+                        title: None,
+                    }),
+                    " mean?".into(),
+                ],
+                cx,
+            )
+        });
+        let send = cx.foreground_executor().spawn(send);
+        cx.run_until_parked();
+
+        model.send_last_completion_stream_text_chunk("Lorem.");
+        model.end_last_completion_stream();
+        cx.run_until_parked();
+        summary_model.send_last_completion_stream_text_chunk("Explaining /a/b.md");
+        summary_model.end_last_completion_stream();
+
+        send.await.unwrap();
+        acp_thread.read_with(cx, |thread, cx| {
+            assert_eq!(
+                thread.to_markdown(cx),
+                indoc! {"
+                    ## User
+
+                    What does [@b.md](file:///a/b.md) mean?
+
+                    ## Assistant
+
+                    Lorem.
+
+                "}
+            )
+        });
+
+        // Drop the ACP thread, which should cause the session to be dropped as well.
+        cx.update(|_| {
+            drop(thread);
+            drop(acp_thread);
+        });
+        agent.read_with(cx, |agent, _| {
+            assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
+        });
+
+        // Ensure the thread can be reloaded from disk.
+        assert_eq!(
+            history_entries(&history_store, cx),
+            vec![(
+                HistoryEntryId::AcpThread(session_id.clone()),
+                "Explaining /a/b.md".into()
+            )]
+        );
+        let acp_thread = agent
+            .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
+            .await
+            .unwrap();
+        acp_thread.read_with(cx, |thread, cx| {
+            assert_eq!(
+                thread.to_markdown(cx),
+                indoc! {"
+                    ## User
+
+                    What does [@b.md](file:///a/b.md) mean?
+
+                    ## Assistant
+
+                    Lorem.
+
+                "}
+            )
+        });
+    }
+
+    fn history_entries(
+        history: &Entity<HistoryStore>,
+        cx: &mut TestAppContext,
+    ) -> Vec<(HistoryEntryId, String)> {
+        history.read_with(cx, |history, cx| {
+            history
+                .entries(cx)
+                .iter()
+                .map(|e| (e.id(), e.title().to_string()))
+                .collect::<Vec<_>>()
+        })
+    }
+
     fn init_test(cx: &mut TestAppContext) {
         env_logger::try_init().ok();
         cx.update(|cx| {

crates/agent2/src/thread.rs 🔗

@@ -720,7 +720,7 @@ impl Thread {
     pub fn to_db(&self, cx: &App) -> Task<DbThread> {
         let initial_project_snapshot = self.initial_project_snapshot.clone();
         let mut thread = DbThread {
-            title: self.title.clone().unwrap_or_default(),
+            title: self.title(),
             messages: self.messages.clone(),
             updated_at: self.updated_at,
             detailed_summary: self.summary.clone(),
@@ -870,6 +870,10 @@ impl Thread {
         &self.action_log
     }
 
+    pub fn is_empty(&self) -> bool {
+        self.messages.is_empty() && self.title.is_none()
+    }
+
     pub fn model(&self) -> Option<&Arc<dyn LanguageModel>> {
         self.model.as_ref()
     }
@@ -884,6 +888,10 @@ impl Thread {
         cx.notify()
     }
 
+    pub fn summarization_model(&self) -> Option<&Arc<dyn LanguageModel>> {
+        self.summarization_model.as_ref()
+    }
+
     pub fn set_summarization_model(
         &mut self,
         model: Option<Arc<dyn LanguageModel>>,
@@ -1068,6 +1076,7 @@ impl Thread {
             event_stream: event_stream.clone(),
             _task: cx.spawn(async move |this, cx| {
                 log::info!("Starting agent turn execution");
+                let mut update_title = None;
                 let turn_result: Result<StopReason> = async {
                     let mut completion_intent = CompletionIntent::UserPrompt;
                     loop {
@@ -1122,10 +1131,15 @@ impl Thread {
                                 this.pending_message()
                                     .tool_results
                                     .insert(tool_result.tool_use_id.clone(), tool_result);
-                            })
-                            .ok();
+                            })?;
                         }
 
+                        this.update(cx, |this, cx| {
+                            if this.title.is_none() && update_title.is_none() {
+                                update_title = Some(this.update_title(&event_stream, cx));
+                            }
+                        })?;
+
                         if tool_use_limit_reached {
                             log::info!("Tool use limit reached, completing turn");
                             this.update(cx, |this, _cx| this.tool_use_limit_reached = true)?;
@@ -1146,10 +1160,6 @@ impl Thread {
                     Ok(reason) => {
                         log::info!("Turn execution completed: {:?}", reason);
 
-                        let update_title = this
-                            .update(cx, |this, cx| this.update_title(&event_stream, cx))
-                            .ok()
-                            .flatten();
                         if let Some(update_title) = update_title {
                             update_title.await.context("update title failed").log_err();
                         }
@@ -1593,17 +1603,14 @@ impl Thread {
         &mut self,
         event_stream: &ThreadEventStream,
         cx: &mut Context<Self>,
-    ) -> Option<Task<Result<()>>> {
-        if self.title.is_some() {
-            log::debug!("Skipping title generation because we already have one.");
-            return None;
-        }
-
+    ) -> Task<Result<()>> {
         log::info!(
             "Generating title with model: {:?}",
             self.summarization_model.as_ref().map(|model| model.name())
         );
-        let model = self.summarization_model.clone()?;
+        let Some(model) = self.summarization_model.clone() else {
+            return Task::ready(Ok(()));
+        };
         let event_stream = event_stream.clone();
         let mut request = LanguageModelRequest {
             intent: Some(CompletionIntent::ThreadSummarization),
@@ -1620,7 +1627,7 @@ impl Thread {
             content: vec![SUMMARIZE_THREAD_PROMPT.into()],
             cache: false,
         });
-        Some(cx.spawn(async move |this, cx| {
+        cx.spawn(async move |this, cx| {
             let mut title = String::new();
             let mut messages = model.stream_completion(request, cx).await?;
             while let Some(event) = messages.next().await {
@@ -1655,7 +1662,7 @@ impl Thread {
                 this.title = Some(title);
                 cx.notify();
             })
-        }))
+        })
     }
 
     fn last_user_message(&self) -> Option<&UserMessage> {
@@ -2457,18 +2464,15 @@ impl From<UserMessageContent> for acp::ContentBlock {
                 uri: None,
             }),
             UserMessageContent::Mention { uri, content } => {
-                acp::ContentBlock::ResourceLink(acp::ResourceLink {
-                    uri: uri.to_uri().to_string(),
-                    name: uri.name(),
+                acp::ContentBlock::Resource(acp::EmbeddedResource {
+                    resource: acp::EmbeddedResourceResource::TextResourceContents(
+                        acp::TextResourceContents {
+                            mime_type: None,
+                            text: content,
+                            uri: uri.to_uri().to_string(),
+                        },
+                    ),
                     annotations: None,
-                    description: if content.is_empty() {
-                        None
-                    } else {
-                        Some(content)
-                    },
-                    mime_type: None,
-                    size: None,
-                    title: None,
                 })
             }
         }