Implement the feature: retain draft threads instead of resetting them

Mikayla Maki created

Change summary

crates/agent/src/agent.rs                    |  13 +
crates/agent/src/native_agent_server.rs      |  11 -
crates/agent/src/thread.rs                   |   4 
crates/agent_ui/src/thread_metadata_store.rs |  13 +
crates/sidebar/src/sidebar.rs                |  86 ++++++++++--
crates/sidebar/src/sidebar_tests.rs          | 153 ++++++++++++++++-----
6 files changed, 212 insertions(+), 68 deletions(-)

Detailed changes

crates/agent/src/agent.rs 🔗

@@ -974,15 +974,20 @@ impl NativeAgent {
     }
 
     fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
-        if thread.read(cx).is_empty() {
-            return;
-        }
-
         let id = thread.read(cx).id().clone();
         let Some(session) = self.sessions.get_mut(&id) else {
             return;
         };
 
+        let has_draft_prompt = session
+            .acp_thread
+            .read(cx)
+            .draft_prompt()
+            .is_some_and(|p| !p.is_empty());
+        if thread.read(cx).is_empty() && !has_draft_prompt {
+            return;
+        }
+
         let project_id = session.project_id;
         let Some(state) = self.projects.get(&project_id) else {
             return;

crates/agent/src/native_agent_server.rs 🔗

@@ -1,4 +1,5 @@
 use std::{any::Any, rc::Rc, sync::Arc};
+use util::ResultExt as _;
 
 use agent_client_protocol as acp;
 use agent_servers::{AgentServer, AgentServerDelegate};
@@ -45,17 +46,13 @@ impl AgentServer for NativeAgentServer {
         let thread_store = self.thread_store.clone();
         let prompt_store = PromptStore::global(cx);
         cx.spawn(async move |cx| {
-            log::debug!("Creating templates for native agent");
             let templates = Templates::new();
-            let prompt_store = prompt_store.await?;
+            let prompt_store = prompt_store.await.log_err();
 
-            log::debug!("Creating native agent entity");
-            let agent = cx
-                .update(|cx| NativeAgent::new(thread_store, templates, Some(prompt_store), fs, cx));
+            let agent =
+                cx.update(|cx| NativeAgent::new(thread_store, templates, prompt_store, fs, cx));
 
-            // Create the connection wrapper
             let connection = NativeAgentConnection(agent);
-            log::debug!("NativeAgentServer connection established successfully");
 
             Ok(Rc::new(connection) as Rc<dyn acp_thread::AgentConnection>)
         })

crates/agent/src/thread.rs 🔗

@@ -1373,7 +1373,9 @@ impl Thread {
     }
 
     pub fn is_empty(&self) -> bool {
-        self.messages.is_empty() && self.title.is_none()
+        self.messages.is_empty()
+            && self.title.is_none()
+            && self.draft_prompt.as_ref().is_none_or(|p| p.is_empty())
     }
 
     pub fn draft_prompt(&self) -> Option<&[acp::ContentBlock]> {

crates/agent_ui/src/thread_metadata_store.rs 🔗

@@ -198,7 +198,12 @@ impl ThreadMetadataStore {
     pub fn init_global(cx: &mut App) {
         let thread = std::thread::current();
         let test_name = thread.name().unwrap_or("unknown_test");
-        let db_name = format!("THREAD_METADATA_DB_{}", test_name);
+        Self::init_global_with_name(test_name, cx);
+    }
+
+    #[cfg(any(test, feature = "test-support"))]
+    pub fn init_global_with_name(name: &str, cx: &mut App) {
+        let db_name = format!("THREAD_METADATA_DB_{}", name);
         let db = smol::block_on(db::open_test_db::<ThreadMetadataDb>(&db_name));
         let thread_store = cx.new(|cx| Self::new(ThreadMetadataDb(db), cx));
         cx.set_global(GlobalThreadMetadataStore(thread_store));
@@ -364,9 +369,9 @@ impl ThreadMetadataStore {
                         .update(cx, |store, cx| {
                             let session_id = thread.session_id().clone();
                             store.session_subscriptions.remove(&session_id);
-                            if thread.entries().is_empty() {
-                                // Empty threads can be unloaded without ever being
-                                // durably persisted by the underlying agent.
+                            let is_blank = thread.entries().is_empty()
+                                && thread.draft_prompt().is_none_or(|p| p.is_empty());
+                            if is_blank {
                                 store.delete(session_id, cx);
                             }
                         })

crates/sidebar/src/sidebar.rs 🔗

@@ -700,10 +700,22 @@ impl Sidebar {
                 if panel.read(cx).active_thread_is_draft(cx)
                     || panel.read(cx).active_conversation_view().is_none()
                 {
-                    let preserving_thread =
-                        matches!(&self.active_entry, Some(ActiveEntry::Thread { .. }))
-                            && self.active_entry_workspace() == Some(active_ws);
-                    if !preserving_thread {
+                    // When the sidebar eagerly sets active_entry to a Thread
+                    // (e.g. via activate_thread_locally), the panel may
+                    // temporarily report as a draft while the conversation
+                    // is still loading. Don't overwrite the Thread entry in
+                    // that case — unless the thread has since been archived.
+                    let thread_is_loading =
+                        if let Some(ActiveEntry::Thread { session_id, .. }) = &self.active_entry {
+                            self.active_entry_workspace() == Some(active_ws)
+                                && !ThreadMetadataStore::global(cx)
+                                    .read(cx)
+                                    .entry(session_id)
+                                    .is_some_and(|m| m.archived)
+                        } else {
+                            false
+                        };
+                    if !thread_is_loading {
                         let draft_session_id = panel
                             .read(cx)
                             .active_conversation_view()
@@ -803,7 +815,7 @@ impl Sidebar {
             let mut waiting_thread_count: usize = 0;
 
             if should_load_threads {
-                let mut seen_session_ids: HashSet<acp::SessionId> = HashSet::new();
+                let mut seen_session_ids: HashSet<acp::SessionId> = HashSet::default();
                 let thread_store = ThreadMetadataStore::global(cx);
 
                 // Load threads from each workspace in the group.
@@ -1523,7 +1535,12 @@ impl Sidebar {
                                     // the new-thread entry becomes visible.
                                     this.collapsed_groups.remove(&path_list_for_new_thread);
                                     this.selection = None;
-                                    this.create_new_thread(&workspace_for_new_thread, window, cx);
+                                    this.create_new_thread(
+                                        &workspace_for_new_thread,
+                                        None,
+                                        window,
+                                        cx,
+                                    );
                                 }
                             })),
                         )
@@ -2002,9 +2019,16 @@ impl Sidebar {
                 self.serialize(cx);
                 self.update_entries(cx);
             }
-            ListEntry::NewThread { workspace, .. } => {
+            ListEntry::NewThread {
+                workspace,
+                draft_thread,
+                ..
+            } => {
                 let workspace = workspace.clone();
-                self.create_new_thread(&workspace, window, cx);
+                let draft_session_id = draft_thread
+                    .as_ref()
+                    .map(|t| t.read(cx).session_id().clone());
+                self.create_new_thread(&workspace, draft_session_id, window, cx);
             }
         }
     }
@@ -3056,12 +3080,13 @@ impl Sidebar {
             return;
         };
 
-        self.create_new_thread(&workspace, window, cx);
+        self.create_new_thread(&workspace, None, window, cx);
     }
 
     fn create_new_thread(
         &mut self,
         workspace: &Entity<Workspace>,
+        draft_session_id: Option<acp::SessionId>,
         window: &mut Window,
         cx: &mut Context<Self>,
     ) {
@@ -3069,7 +3094,10 @@ impl Sidebar {
             return;
         };
 
-        self.active_entry = Some(ActiveEntry::draft_for_workspace(workspace.clone()));
+        self.active_entry = Some(ActiveEntry::Draft {
+            session_id: draft_session_id.clone(),
+            workspace: workspace.clone(),
+        });
 
         multi_workspace.update(cx, |multi_workspace, cx| {
             multi_workspace.activate(workspace.clone(), window, cx);
@@ -3078,7 +3106,19 @@ impl Sidebar {
         workspace.update(cx, |workspace, cx| {
             if let Some(agent_panel) = workspace.panel::<AgentPanel>(cx) {
                 agent_panel.update(cx, |panel, cx| {
-                    panel.new_thread(&NewThread, window, cx);
+                    if let Some(session_id) = draft_session_id {
+                        panel.load_agent_thread(
+                            Agent::NativeAgent,
+                            session_id,
+                            None,
+                            None,
+                            true,
+                            window,
+                            cx,
+                        );
+                    } else {
+                        panel.new_thread(&NewThread, window, cx);
+                    }
                 });
             }
             workspace.focus_panel::<AgentPanel>(window, cx);
@@ -3101,6 +3141,7 @@ impl Sidebar {
             .unwrap_or_else(|| DEFAULT_THREAD_TITLE.into());
 
         let workspace = workspace.clone();
+        let draft_session_id = draft_thread.map(|thread| thread.read(cx).session_id().clone());
         let id = SharedString::from(format!("new-thread-btn-{}", ix));
 
         let thread_item = ThreadItem::new(id, label)
@@ -3121,7 +3162,7 @@ impl Sidebar {
             .when(!is_active, |this| {
                 this.on_click(cx.listener(move |this, _, window, cx| {
                     this.selection = None;
-                    this.create_new_thread(&workspace, window, cx);
+                    this.create_new_thread(&workspace, draft_session_id.clone(), window, cx);
                 }))
             });
 
@@ -3636,8 +3677,25 @@ fn summarize_content_blocks(blocks: &[acp::ContentBlock]) -> Option<SharedString
             acp::ContentBlock::ResourceLink(link) => {
                 text.push_str(&format!("@{}", link.name));
             }
-            acp::ContentBlock::Image(_) => {
-                text.push_str("[image]");
+            acp::ContentBlock::Resource(resource) => {
+                if let acp::EmbeddedResourceResource::TextResourceContents(
+                    acp::TextResourceContents { uri, .. },
+                ) = &resource.resource
+                {
+                    let name = uri.rsplit('/').next().unwrap_or(uri);
+                    text.push_str(&format!("@{}", name));
+                }
+            }
+            acp::ContentBlock::Image(image) => {
+                let name = image
+                    .uri
+                    .as_ref()
+                    .map(|uri| uri.rsplit('/').next().unwrap_or(uri))
+                    .unwrap_or(&image.mime_type);
+                text.push_str(&format!("@{}", name));
+            }
+            agent_client_protocol::ContentBlock::Audio(audio) => {
+                text.push_str(&format!("@{}", audio.mime_type));
             }
             _ => {}
         }

crates/sidebar/src/sidebar_tests.rs 🔗

@@ -154,6 +154,56 @@ async fn save_thread_metadata(
     cx.run_until_parked();
 }
 
+fn save_thread_with_content(
+    session_id: &acp::SessionId,
+    path_list: PathList,
+    cx: &mut gpui::VisualTestContext,
+) {
+    let title: SharedString = "Test".into();
+    let updated_at = chrono::TimeZone::with_ymd_and_hms(&chrono::Utc, 2024, 1, 1, 0, 0, 0).unwrap();
+    let metadata = ThreadMetadata {
+        session_id: session_id.clone(),
+        agent_id: agent::ZED_AGENT_ID.clone(),
+        title: title.clone(),
+        updated_at,
+        created_at: None,
+        folder_paths: path_list.clone(),
+        archived: false,
+    };
+    let session_id = session_id.clone();
+    cx.update(|_, cx| {
+        ThreadMetadataStore::global(cx).update(cx, |store, cx| store.save(metadata, cx));
+
+        let db_thread = agent::DbThread {
+            title,
+            messages: vec![agent::Message::User(agent::UserMessage {
+                id: acp_thread::UserMessageId::new(),
+                content: vec![agent::UserMessageContent::Text("Hello".to_string())],
+            })],
+            updated_at,
+            detailed_summary: None,
+            initial_project_snapshot: None,
+            cumulative_token_usage: Default::default(),
+            request_token_usage: Default::default(),
+            model: None,
+            profile: None,
+            imported: false,
+            subagent_context: None,
+            speed: None,
+            thinking_enabled: false,
+            thinking_effort: None,
+            draft_prompt: None,
+            ui_scroll_position: None,
+        };
+        ThreadStore::global(cx)
+            .update(cx, |store, cx| {
+                store.save_thread(session_id, db_thread, path_list, cx)
+            })
+            .detach_and_log_err(cx);
+    });
+    cx.run_until_parked();
+}
+
 fn open_and_focus_sidebar(sidebar: &Entity<Sidebar>, cx: &mut gpui::VisualTestContext) {
     let multi_workspace = sidebar.read_with(cx, |s, _| s.multi_workspace.upgrade());
     if let Some(multi_workspace) = multi_workspace {
@@ -2272,7 +2322,7 @@ async fn test_new_thread_button_works_after_adding_folder(cx: &mut TestAppContex
     // verify a new draft is created.
     let workspace = multi_workspace.read_with(cx, |mw, _cx| mw.workspace().clone());
     sidebar.update_in(cx, |sidebar, window, cx| {
-        sidebar.create_new_thread(&workspace, window, cx);
+        sidebar.create_new_thread(&workspace, None, window, cx);
     });
     cx.run_until_parked();
 
@@ -4994,16 +5044,47 @@ mod property_test {
             .unwrap()
             + chrono::Duration::seconds(state.thread_counter as i64);
         let metadata = ThreadMetadata {
-            session_id,
+            session_id: session_id.clone(),
             agent_id: agent::ZED_AGENT_ID.clone(),
-            title,
+            title: title.clone(),
             updated_at,
             created_at: None,
-            folder_paths: path_list,
+            folder_paths: path_list.clone(),
             archived: false,
         };
+        // Save to both stores: ThreadMetadataStore (used by sidebar for
+        // listing) and ThreadStore (used by NativeAgentServer for loading
+        // thread content). In production these are populated through
+        // different paths, but tests need both.
         cx.update(|_, cx| {
             ThreadMetadataStore::global(cx).update(cx, |store, cx| store.save(metadata, cx));
+
+            let db_thread = agent::DbThread {
+                title,
+                messages: vec![agent::Message::User(agent::UserMessage {
+                    id: acp_thread::UserMessageId::new(),
+                    content: vec![agent::UserMessageContent::Text("Hello".to_string())],
+                })],
+                updated_at,
+                detailed_summary: None,
+                initial_project_snapshot: None,
+                cumulative_token_usage: Default::default(),
+                request_token_usage: Default::default(),
+                model: None,
+                profile: None,
+                imported: false,
+                subagent_context: None,
+                speed: None,
+                thinking_enabled: false,
+                thinking_effort: None,
+                draft_prompt: None,
+                ui_scroll_position: None,
+            };
+            ThreadStore::global(cx)
+                .update(cx, |store, cx| {
+                    store.save_thread(session_id, db_thread, path_list, cx)
+                })
+                .detach_and_log_err(cx);
         });
     }
 
@@ -5018,9 +5099,24 @@ mod property_test {
             Operation::SaveThread { workspace_index } => {
                 let workspace =
                     multi_workspace.read_with(cx, |mw, _| mw.workspaces()[workspace_index].clone());
-                let path_list = workspace
-                    .read_with(cx, |workspace, cx| PathList::new(&workspace.root_paths(cx)));
-                save_thread_to_path(state, path_list, cx);
+                let panel =
+                    workspace.read_with(cx, |workspace, cx| workspace.panel::<AgentPanel>(cx));
+                if let Some(panel) = panel {
+                    let title = format!("Thread {}", state.thread_counter);
+                    let connection = StubAgentConnection::new();
+                    connection.set_next_prompt_updates(vec![
+                        acp::SessionUpdate::AgentMessageChunk(acp::ContentChunk::new(title.into())),
+                    ]);
+                    open_thread_with_connection(&panel, connection, cx);
+                    send_message(&panel, cx);
+                    let session_id = active_session_id(&panel, cx);
+                    state.thread_counter += 1;
+                    state.saved_thread_ids.push(session_id.clone());
+
+                    let path_list = workspace
+                        .read_with(cx, |workspace, cx| PathList::new(&workspace.root_paths(cx)));
+                    save_thread_with_content(&session_id, path_list, cx);
+                }
             }
             Operation::SaveWorktreeThread { worktree_index } => {
                 let worktree = &state.unopened_worktrees[worktree_index];
@@ -5068,33 +5164,8 @@ mod property_test {
                         .find(|m| m.session_id == session_id)
                 });
                 if let Some(metadata) = metadata {
-                    let panel =
-                        workspace.read_with(cx, |workspace, cx| workspace.panel::<AgentPanel>(cx));
-                    if let Some(panel) = panel {
-                        let connection = StubAgentConnection::new();
-                        connection.set_next_prompt_updates(vec![
-                            acp::SessionUpdate::AgentMessageChunk(acp::ContentChunk::new(
-                                metadata.title.to_string().into(),
-                            )),
-                        ]);
-                        open_thread_with_connection(&panel, connection, cx);
-                        send_message(&panel, cx);
-                        let panel_session_id = active_session_id(&panel, cx);
-                        // Replace the old metadata entry with one that
-                        // uses the panel's actual session ID.
-                        let old_session_id = metadata.session_id.clone();
-                        let mut updated_metadata = metadata.clone();
-                        updated_metadata.session_id = panel_session_id.clone();
-                        cx.update(|_, cx| {
-                            ThreadMetadataStore::global(cx).update(cx, |store, cx| {
-                                store.delete(old_session_id, cx);
-                                store.save(updated_metadata, cx);
-                            });
-                        });
-                        state.saved_thread_ids[index] = panel_session_id;
-                    }
-                    _sidebar.update_in(cx, |sidebar, _window, cx| {
-                        sidebar.update_entries(cx);
+                    _sidebar.update_in(cx, |sidebar, window, cx| {
+                        sidebar.activate_thread_locally(&metadata, &workspace, window, cx);
                     });
                 }
             }
@@ -5423,18 +5494,24 @@ mod property_test {
         raw_operations: Vec<u32>,
         cx: &mut TestAppContext,
     ) {
+        use std::sync::atomic::{AtomicUsize, Ordering};
+        static ITERATION: AtomicUsize = AtomicUsize::new(0);
+        let iteration = ITERATION.fetch_add(1, Ordering::SeqCst);
+        let project_path = format!("/my-project-{iteration}");
+        let db_name = format!("sidebar_invariants_{iteration}");
+
         agent_ui::test_support::init_test(cx);
         cx.update(|cx| {
             cx.update_flags(false, vec!["agent-v2".into()]);
             ThreadStore::init_global(cx);
-            ThreadMetadataStore::init_global(cx);
+            ThreadMetadataStore::init_global_with_name(&db_name, cx);
             language_model::LanguageModelRegistry::test(cx);
             prompt_store::init(cx);
         });
 
         let fs = FakeFs::new(cx.executor());
         fs.insert_tree(
-            "/my-project",
+            &project_path,
             serde_json::json!({
                 ".git": {},
                 "src": {},
@@ -5443,7 +5520,7 @@ mod property_test {
         .await;
         cx.update(|cx| <dyn fs::Fs>::set_global(fs.clone(), cx));
         let project =
-            project::Project::test(fs.clone() as Arc<dyn fs::Fs>, ["/my-project".as_ref()], cx)
+            project::Project::test(fs.clone() as Arc<dyn fs::Fs>, [project_path.as_ref()], cx)
                 .await;
         project.update(cx, |p, cx| p.git_scans_complete(cx)).await;
 
@@ -5451,7 +5528,7 @@ mod property_test {
             cx.add_window_view(|window, cx| MultiWorkspace::test_new(project.clone(), window, cx));
         let (sidebar, _panel) = setup_sidebar_with_agent_panel(&multi_workspace, &project, cx);
 
-        let mut state = TestState::new(fs, "/my-project".to_string());
+        let mut state = TestState::new(fs, project_path);
         let mut executed: Vec<String> = Vec::new();
 
         for &raw_op in &raw_operations {