diff --git a/crates/agent/src/agent.rs b/crates/agent/src/agent.rs index 77c326feec60514d459e6026a39f1bcd5ed8a896..6437fd1883c9ddbb256babbb88041b4c42293a95 100644 --- a/crates/agent/src/agent.rs +++ b/crates/agent/src/agent.rs @@ -82,7 +82,7 @@ struct Session { /// The ACP thread that handles protocol communication acp_thread: Entity, project_id: EntityId, - pending_save: Task<()>, + pending_save: Task>, _subscriptions: Vec, } @@ -387,7 +387,7 @@ impl NativeAgent { acp_thread: acp_thread.clone(), project_id, _subscriptions: subscriptions, - pending_save: Task::ready(()), + pending_save: Task::ready(Ok(())), }, ); @@ -1000,7 +1000,7 @@ impl NativeAgent { let thread_store = self.thread_store.clone(); session.pending_save = cx.spawn(async move |_, cx| { let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else { - return; + return Ok(()); }; let db_thread = db_thread.await; database @@ -1008,6 +1008,7 @@ impl NativeAgent { .await .log_err(); thread_store.update(cx, |store, cx| store.reload(cx)); + Ok(()) }); } @@ -1444,18 +1445,23 @@ impl acp_thread::AgentConnection for NativeAgentConnection { cx: &mut App, ) -> Task> { self.0.update(cx, |agent, cx| { + let thread = agent.sessions.get(session_id).map(|s| s.thread.clone()); + if let Some(thread) = thread { + agent.save_thread(thread, cx); + } + let Some(session) = agent.sessions.remove(session_id) else { - return; + return Task::ready(Ok(())); }; let project_id = session.project_id; - agent.save_thread(session.thread, cx); let has_remaining = agent.sessions.values().any(|s| s.project_id == project_id); if !has_remaining { agent.projects.remove(&project_id); } - }); - Task::ready(Ok(())) + + session.pending_save + }) } fn auth_methods(&self) -> &[acp::AuthMethod] { @@ -2830,7 +2836,9 @@ mod internal_tests { cx.run_until_parked(); - // Set a draft prompt with rich content blocks before saving. + // Set a draft prompt with rich content blocks and scroll position + // AFTER run_until_parked, so the only save that captures these + // changes is the one performed by close_session itself. let draft_blocks = vec![ acp::ContentBlock::Text(acp::TextContent::new("Check out ")), acp::ContentBlock::ResourceLink(acp::ResourceLink::new("b.md", uri.to_string())), @@ -2845,8 +2853,6 @@ mod internal_tests { offset_in_item: gpui::px(12.5), })); }); - thread.update(cx, |_thread, cx| cx.notify()); - cx.run_until_parked(); // Close the session so it can be reloaded from disk. cx.update(|cx| connection.clone().close_session(&session_id, cx)) @@ -2912,6 +2918,87 @@ mod internal_tests { }); } + #[gpui::test] + async fn test_close_session_saves_thread(cx: &mut TestAppContext) { + init_test(cx); + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/", + json!({ + "a": { + "file.txt": "hello" + } + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await; + let thread_store = cx.new(|cx| ThreadStore::new(cx)); + let agent = cx.update(|cx| { + NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx) + }); + let connection = Rc::new(NativeAgentConnection(agent.clone())); + + let acp_thread = cx + .update(|cx| { + connection + .clone() + .new_session(project.clone(), PathList::new(&[Path::new("")]), cx) + }) + .await + .unwrap(); + let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone()); + let thread = agent.read_with(cx, |agent, _| { + agent.sessions.get(&session_id).unwrap().thread.clone() + }); + + let model = Arc::new(FakeLanguageModel::default()); + thread.update(cx, |thread, cx| { + thread.set_model(model.clone(), cx); + }); + + // Send a message so the thread is non-empty (empty threads aren't saved). + let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)); + let send = cx.foreground_executor().spawn(send); + cx.run_until_parked(); + + model.send_last_completion_stream_text_chunk("world"); + model.end_last_completion_stream(); + send.await.unwrap(); + cx.run_until_parked(); + + // Set a draft prompt WITHOUT calling run_until_parked afterwards. + // This means no observe-triggered save has run for this change. + // The only way this data gets persisted is if close_session + // itself performs the save. + let draft_blocks = vec![acp::ContentBlock::Text(acp::TextContent::new( + "unsaved draft", + ))]; + acp_thread.update(cx, |thread, _cx| { + thread.set_draft_prompt(Some(draft_blocks.clone())); + }); + + // Close the session immediately — no run_until_parked in between. + cx.update(|cx| connection.clone().close_session(&session_id, cx)) + .await + .unwrap(); + cx.run_until_parked(); + + // Reopen and verify the draft prompt was saved. + let reloaded = agent + .update(cx, |agent, cx| { + agent.open_thread(session_id.clone(), project.clone(), cx) + }) + .await + .unwrap(); + reloaded.read_with(cx, |thread, _| { + assert_eq!( + thread.draft_prompt(), + Some(draft_blocks.as_slice()), + "close_session must save the thread; draft prompt was lost" + ); + }); + } + fn thread_entries( thread_store: &Entity, cx: &mut TestAppContext,