diff --git a/crates/agent/src/agent.rs b/crates/agent/src/agent.rs index 143f74bfa90a1bb686bc88dc9942816ca42510ee..eaa8de69fca2d11434059a945ce05deddf56cb20 100644 --- a/crates/agent/src/agent.rs +++ b/crates/agent/src/agent.rs @@ -84,6 +84,12 @@ struct Session { project_id: EntityId, pending_save: Task>, _subscriptions: Vec, + ref_count: usize, +} + +struct PendingSession { + task: Shared, Arc>>>, + ref_count: usize, } pub struct LanguageModels { @@ -245,6 +251,7 @@ impl LanguageModels { pub struct NativeAgent { /// Session ID -> Session mapping sessions: HashMap, + pending_sessions: HashMap, thread_store: Entity, /// Project-specific state keyed by project EntityId projects: HashMap, @@ -278,6 +285,7 @@ impl NativeAgent { Self { sessions: HashMap::default(), + pending_sessions: HashMap::default(), thread_store, projects: HashMap::default(), templates, @@ -316,13 +324,14 @@ impl NativeAgent { ) }); - self.register_session(thread, project_id, cx) + self.register_session(thread, project_id, 1, cx) } fn register_session( &mut self, thread_handle: Entity, project_id: EntityId, + ref_count: usize, cx: &mut Context, ) -> Entity { let connection = Rc::new(NativeAgentConnection(cx.entity())); @@ -388,6 +397,7 @@ impl NativeAgent { project_id, _subscriptions: subscriptions, pending_save: Task::ready(Ok(())), + ref_count, }, ); @@ -926,27 +936,68 @@ impl NativeAgent { project: Entity, cx: &mut Context, ) -> Task>> { - if let Some(session) = self.sessions.get(&id) { + if let Some(session) = self.sessions.get_mut(&id) { + session.ref_count += 1; return Task::ready(Ok(session.acp_thread.clone())); } - let task = self.load_thread(id, project.clone(), cx); - cx.spawn(async move |this, cx| { - let thread = task.await?; - let acp_thread = this.update(cx, |this, cx| { - let project_id = this.get_or_create_project_state(&project, cx); - this.register_session(thread.clone(), project_id, cx) - })?; - let events = thread.update(cx, |thread, cx| thread.replay(cx)); - cx.update(|cx| { - NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx) + if let Some(pending) = self.pending_sessions.get_mut(&id) { + pending.ref_count += 1; + let task = pending.task.clone(); + return cx.spawn(async move |_, _cx| task.await.map_err(|err| anyhow!(err))); + } + + let task = self.load_thread(id.clone(), project.clone(), cx); + let shared_task = cx + .spawn({ + let id = id.clone(); + async move |this, cx| { + let thread = match task.await { + Ok(thread) => thread, + Err(err) => { + this.update(cx, |this, _cx| { + this.pending_sessions.remove(&id); + }) + .ok(); + return Err(Arc::new(err)); + } + }; + let acp_thread = this + .update(cx, |this, cx| { + let project_id = this.get_or_create_project_state(&project, cx); + let ref_count = this + .pending_sessions + .remove(&id) + .map_or(1, |pending| pending.ref_count); + this.register_session(thread.clone(), project_id, ref_count, cx) + }) + .map_err(Arc::new)?; + let events = thread.update(cx, |thread, cx| thread.replay(cx)); + cx.update(|cx| { + NativeAgentConnection::handle_thread_events( + events, + acp_thread.downgrade(), + cx, + ) + }) + .await + .map_err(Arc::new)?; + acp_thread.update(cx, |thread, cx| { + thread.snapshot_completed_plan(cx); + }); + Ok(acp_thread) + } }) - .await?; - acp_thread.update(cx, |thread, cx| { - thread.snapshot_completed_plan(cx); - }); - Ok(acp_thread) - }) + .shared(); + self.pending_sessions.insert( + id, + PendingSession { + task: shared_task.clone(), + ref_count: 1, + }, + ); + + cx.background_spawn(async move { shared_task.await.map_err(|err| anyhow!(err)) }) } pub fn thread_summary( @@ -968,11 +1019,43 @@ impl NativeAgent { })? .await .context("Failed to generate summary")?; + + this.update(cx, |this, cx| this.close_session(&id, cx))? + .await?; drop(acp_thread); Ok(result) }) } + fn close_session( + &mut self, + session_id: &acp::SessionId, + cx: &mut Context, + ) -> Task> { + let Some(session) = self.sessions.get_mut(session_id) else { + return Task::ready(Ok(())); + }; + + session.ref_count -= 1; + if session.ref_count > 0 { + return Task::ready(Ok(())); + } + + let thread = session.thread.clone(); + self.save_thread(thread, cx); + let Some(session) = self.sessions.remove(session_id) else { + return Task::ready(Ok(())); + }; + let project_id = session.project_id; + + let has_remaining = self.sessions.values().any(|s| s.project_id == project_id); + if !has_remaining { + self.projects.remove(&project_id); + } + + session.pending_save + } + fn save_thread(&mut self, thread: Entity, cx: &mut Context) { if thread.read(cx).is_empty() { return; @@ -1158,6 +1241,7 @@ impl NativeAgentConnection { .get_mut(&session_id) .map(|s| (s.thread.clone(), s.acp_thread.clone())) }) else { + log::error!("Session not found in run_turn: {}", session_id); return Task::ready(Err(anyhow!("Session not found"))); }; log::debug!("Found session for: {}", session_id); @@ -1452,24 +1536,8 @@ impl acp_thread::AgentConnection for NativeAgentConnection { session_id: &acp::SessionId, cx: &mut App, ) -> Task> { - self.0.update(cx, |agent, cx| { - let thread = agent.sessions.get(session_id).map(|s| s.thread.clone()); - if let Some(thread) = thread { - agent.save_thread(thread, cx); - } - - let Some(session) = agent.sessions.remove(session_id) else { - return Task::ready(Ok(())); - }; - let project_id = session.project_id; - - let has_remaining = agent.sessions.values().any(|s| s.project_id == project_id); - if !has_remaining { - agent.projects.remove(&project_id); - } - - session.pending_save - }) + self.0 + .update(cx, |agent, cx| agent.close_session(session_id, cx)) } fn auth_methods(&self) -> &[acp::AuthMethod] { @@ -1498,6 +1566,13 @@ impl acp_thread::AgentConnection for NativeAgentConnection { log::debug!("Prompt blocks count: {}", params.prompt.len()); let Some(project_state) = self.0.read(cx).session_project_state(&session_id) else { + log::error!("Session not found in prompt: {}", session_id); + if self.0.read(cx).sessions.contains_key(&session_id) { + log::error!( + "Session found in sessions map, but not in project state: {}", + session_id + ); + } return Task::ready(Err(anyhow::anyhow!("Session not found"))); }; @@ -1812,7 +1887,7 @@ impl NativeThreadEnvironment { .get(&parent_session_id) .map(|s| s.project_id) .context("parent session not found")?; - Ok(agent.register_session(subagent_thread.clone(), project_id, cx)) + Ok(agent.register_session(subagent_thread.clone(), project_id, 1, cx)) })??; let depth = current_depth + 1; @@ -3006,6 +3081,241 @@ mod internal_tests { }); } + #[gpui::test] + async fn test_thread_summary_releases_loaded_session(cx: &mut TestAppContext) { + init_test(cx); + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/", + json!({ + "a": { + "file.txt": "hello" + } + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await; + let thread_store = cx.new(|cx| ThreadStore::new(cx)); + let agent = cx.update(|cx| { + NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx) + }); + let connection = Rc::new(NativeAgentConnection(agent.clone())); + + let acp_thread = cx + .update(|cx| { + connection + .clone() + .new_session(project.clone(), PathList::new(&[Path::new("")]), cx) + }) + .await + .unwrap(); + let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone()); + let thread = agent.read_with(cx, |agent, _| { + agent.sessions.get(&session_id).unwrap().thread.clone() + }); + + let model = Arc::new(FakeLanguageModel::default()); + let summary_model = Arc::new(FakeLanguageModel::default()); + thread.update(cx, |thread, cx| { + thread.set_model(model.clone(), cx); + thread.set_summarization_model(Some(summary_model.clone()), cx); + }); + + let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)); + let send = cx.foreground_executor().spawn(send); + cx.run_until_parked(); + + model.send_last_completion_stream_text_chunk("world"); + model.end_last_completion_stream(); + send.await.unwrap(); + cx.run_until_parked(); + + let summary = agent.update(cx, |agent, cx| { + agent.thread_summary(session_id.clone(), project.clone(), cx) + }); + cx.run_until_parked(); + + summary_model.send_last_completion_stream_text_chunk("summary"); + summary_model.end_last_completion_stream(); + + assert_eq!(summary.await.unwrap(), "summary"); + cx.run_until_parked(); + + agent.read_with(cx, |agent, _| { + let session = agent + .sessions + .get(&session_id) + .expect("thread_summary should not close the active session"); + assert_eq!( + session.ref_count, 1, + "thread_summary should release its temporary session reference" + ); + }); + + cx.update(|cx| connection.clone().close_session(&session_id, cx)) + .await + .unwrap(); + cx.run_until_parked(); + + agent.read_with(cx, |agent, _| { + assert!( + agent.sessions.is_empty(), + "closing the active session after thread_summary should unload it" + ); + }); + } + + #[gpui::test] + async fn test_loaded_sessions_keep_state_until_last_close(cx: &mut TestAppContext) { + init_test(cx); + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/", + json!({ + "a": { + "file.txt": "hello" + } + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await; + let thread_store = cx.new(|cx| ThreadStore::new(cx)); + let agent = cx.update(|cx| { + NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx) + }); + let connection = Rc::new(NativeAgentConnection(agent.clone())); + + let acp_thread = cx + .update(|cx| { + connection + .clone() + .new_session(project.clone(), PathList::new(&[Path::new("")]), cx) + }) + .await + .unwrap(); + let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone()); + let thread = agent.read_with(cx, |agent, _| { + agent.sessions.get(&session_id).unwrap().thread.clone() + }); + + let model = cx.update(|cx| { + LanguageModelRegistry::read_global(cx) + .default_model() + .map(|default_model| default_model.model) + .expect("default test model should be available") + }); + let fake_model = model.as_fake(); + thread.update(cx, |thread, cx| { + thread.set_model(model.clone(), cx); + }); + + let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)); + let send = cx.foreground_executor().spawn(send); + cx.run_until_parked(); + + fake_model.send_last_completion_stream_text_chunk("world"); + fake_model.end_last_completion_stream(); + send.await.unwrap(); + cx.run_until_parked(); + + cx.update(|cx| connection.clone().close_session(&session_id, cx)) + .await + .unwrap(); + drop(thread); + drop(acp_thread); + agent.read_with(cx, |agent, _| { + assert!(agent.sessions.is_empty()); + }); + + let first_loaded_thread = cx.update(|cx| { + connection.clone().load_session( + session_id.clone(), + project.clone(), + PathList::new(&[Path::new("")]), + None, + cx, + ) + }); + let second_loaded_thread = cx.update(|cx| { + connection.clone().load_session( + session_id.clone(), + project.clone(), + PathList::new(&[Path::new("")]), + None, + cx, + ) + }); + + let first_loaded_thread = first_loaded_thread.await.unwrap(); + let second_loaded_thread = second_loaded_thread.await.unwrap(); + + cx.run_until_parked(); + + assert_eq!( + first_loaded_thread.entity_id(), + second_loaded_thread.entity_id(), + "concurrent loads for the same session should share one AcpThread" + ); + + cx.update(|cx| connection.clone().close_session(&session_id, cx)) + .await + .unwrap(); + + agent.read_with(cx, |agent, _| { + assert!( + agent.sessions.contains_key(&session_id), + "closing one loaded session should not drop shared session state" + ); + }); + + let follow_up = second_loaded_thread.update(cx, |thread, cx| { + thread.send(vec!["still there?".into()], cx) + }); + let follow_up = cx.foreground_executor().spawn(follow_up); + cx.run_until_parked(); + + fake_model.send_last_completion_stream_text_chunk("yes"); + fake_model.end_last_completion_stream(); + follow_up.await.unwrap(); + cx.run_until_parked(); + + second_loaded_thread.read_with(cx, |thread, cx| { + assert_eq!( + thread.to_markdown(cx), + formatdoc! {" + ## User + + hello + + ## Assistant + + world + + ## User + + still there? + + ## Assistant + + yes + + "} + ); + }); + + cx.update(|cx| connection.clone().close_session(&session_id, cx)) + .await + .unwrap(); + + cx.run_until_parked(); + + drop(first_loaded_thread); + drop(second_loaded_thread); + agent.read_with(cx, |agent, _| { + assert!(agent.sessions.is_empty()); + }); + } + #[gpui::test] async fn test_rapid_title_changes_do_not_loop(cx: &mut TestAppContext) { // Regression test: rapid title changes must not cause a propagation loop