agent: Fix "Session not found" error (#53999)

Bennet Bo Fenner created

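Summary: sessions in `NativeAgent` are now reference-counted. Opening a session that is already loaded (or still loading) increments its count and reuses the same `AcpThread`; concurrent loads share a single pending task. `close_session` only saves and unloads the session once the last reference is released, so closing one handle no longer causes "Session not found" errors for the others. `thread_summary` releases its temporary reference when it finishes.

Self-Review Checklist: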

- [x] I've reviewed my own diff for quality, security, and reliability
- [x] Unsafe blocks (if any) have justifying comments
- [x] The content is consistent with the [UI/UX checklist](https://github.com/zed-industries/zed/blob/main/CONTRIBUTING.md#uiux-checklist)
- [x] Tests cover the new/changed behavior
- [x] Performance impact has been considered and is acceptable

Release Notes:

- N/A

Change summary

crates/agent/src/agent.rs | 384 +++++++++++++++++++++++++++++++++++++---
1 file changed, 347 insertions(+), 37 deletions(-)

Detailed changes

crates/agent/src/agent.rs 🔗

@@ -84,6 +84,12 @@ struct Session {
     project_id: EntityId,
     pending_save: Task<Result<()>>,
     _subscriptions: Vec<Subscription>,
+    ref_count: usize,
+}
+
+struct PendingSession {
+    task: Shared<Task<Result<Entity<AcpThread>, Arc<anyhow::Error>>>>,
+    ref_count: usize,
 }
 
 pub struct LanguageModels {
@@ -245,6 +251,7 @@ impl LanguageModels {
 pub struct NativeAgent {
     /// Session ID -> Session mapping
     sessions: HashMap<acp::SessionId, Session>,
+    pending_sessions: HashMap<acp::SessionId, PendingSession>,
     thread_store: Entity<ThreadStore>,
     /// Project-specific state keyed by project EntityId
     projects: HashMap<EntityId, ProjectState>,
@@ -278,6 +285,7 @@ impl NativeAgent {
 
             Self {
                 sessions: HashMap::default(),
+                pending_sessions: HashMap::default(),
                 thread_store,
                 projects: HashMap::default(),
                 templates,
@@ -316,13 +324,14 @@ impl NativeAgent {
             )
         });
 
-        self.register_session(thread, project_id, cx)
+        self.register_session(thread, project_id, 1, cx)
     }
 
     fn register_session(
         &mut self,
         thread_handle: Entity<Thread>,
         project_id: EntityId,
+        ref_count: usize,
         cx: &mut Context<Self>,
     ) -> Entity<AcpThread> {
         let connection = Rc::new(NativeAgentConnection(cx.entity()));
@@ -388,6 +397,7 @@ impl NativeAgent {
                 project_id,
                 _subscriptions: subscriptions,
                 pending_save: Task::ready(Ok(())),
+                ref_count,
             },
         );
 
@@ -926,27 +936,68 @@ impl NativeAgent {
         project: Entity<Project>,
         cx: &mut Context<Self>,
     ) -> Task<Result<Entity<AcpThread>>> {
-        if let Some(session) = self.sessions.get(&id) {
+        if let Some(session) = self.sessions.get_mut(&id) {
+            session.ref_count += 1;
             return Task::ready(Ok(session.acp_thread.clone()));
         }
 
-        let task = self.load_thread(id, project.clone(), cx);
-        cx.spawn(async move |this, cx| {
-            let thread = task.await?;
-            let acp_thread = this.update(cx, |this, cx| {
-                let project_id = this.get_or_create_project_state(&project, cx);
-                this.register_session(thread.clone(), project_id, cx)
-            })?;
-            let events = thread.update(cx, |thread, cx| thread.replay(cx));
-            cx.update(|cx| {
-                NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
+        if let Some(pending) = self.pending_sessions.get_mut(&id) {
+            pending.ref_count += 1;
+            let task = pending.task.clone();
+            return cx.spawn(async move |_, _cx| task.await.map_err(|err| anyhow!(err)));
+        }
+
+        let task = self.load_thread(id.clone(), project.clone(), cx);
+        let shared_task = cx
+            .spawn({
+                let id = id.clone();
+                async move |this, cx| {
+                    let thread = match task.await {
+                        Ok(thread) => thread,
+                        Err(err) => {
+                            this.update(cx, |this, _cx| {
+                                this.pending_sessions.remove(&id);
+                            })
+                            .ok();
+                            return Err(Arc::new(err));
+                        }
+                    };
+                    let acp_thread = this
+                        .update(cx, |this, cx| {
+                            let project_id = this.get_or_create_project_state(&project, cx);
+                            let ref_count = this
+                                .pending_sessions
+                                .remove(&id)
+                                .map_or(1, |pending| pending.ref_count);
+                            this.register_session(thread.clone(), project_id, ref_count, cx)
+                        })
+                        .map_err(Arc::new)?;
+                    let events = thread.update(cx, |thread, cx| thread.replay(cx));
+                    cx.update(|cx| {
+                        NativeAgentConnection::handle_thread_events(
+                            events,
+                            acp_thread.downgrade(),
+                            cx,
+                        )
+                    })
+                    .await
+                    .map_err(Arc::new)?;
+                    acp_thread.update(cx, |thread, cx| {
+                        thread.snapshot_completed_plan(cx);
+                    });
+                    Ok(acp_thread)
+                }
             })
-            .await?;
-            acp_thread.update(cx, |thread, cx| {
-                thread.snapshot_completed_plan(cx);
-            });
-            Ok(acp_thread)
-        })
+            .shared();
+        self.pending_sessions.insert(
+            id,
+            PendingSession {
+                task: shared_task.clone(),
+                ref_count: 1,
+            },
+        );
+
+        cx.background_spawn(async move { shared_task.await.map_err(|err| anyhow!(err)) })
     }
 
     pub fn thread_summary(
@@ -968,11 +1019,43 @@ impl NativeAgent {
                 })?
                 .await
                 .context("Failed to generate summary")?;
+
+            this.update(cx, |this, cx| this.close_session(&id, cx))?
+                .await?;
             drop(acp_thread);
             Ok(result)
         })
     }
 
+    fn close_session(
+        &mut self,
+        session_id: &acp::SessionId,
+        cx: &mut Context<Self>,
+    ) -> Task<Result<()>> {
+        let Some(session) = self.sessions.get_mut(session_id) else {
+            return Task::ready(Ok(()));
+        };
+
+        session.ref_count -= 1;
+        if session.ref_count > 0 {
+            return Task::ready(Ok(()));
+        }
+
+        let thread = session.thread.clone();
+        self.save_thread(thread, cx);
+        let Some(session) = self.sessions.remove(session_id) else {
+            return Task::ready(Ok(()));
+        };
+        let project_id = session.project_id;
+
+        let has_remaining = self.sessions.values().any(|s| s.project_id == project_id);
+        if !has_remaining {
+            self.projects.remove(&project_id);
+        }
+
+        session.pending_save
+    }
+
     fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
         if thread.read(cx).is_empty() {
             return;
@@ -1158,6 +1241,7 @@ impl NativeAgentConnection {
                 .get_mut(&session_id)
                 .map(|s| (s.thread.clone(), s.acp_thread.clone()))
         }) else {
+            log::error!("Session not found in run_turn: {}", session_id);
             return Task::ready(Err(anyhow!("Session not found")));
         };
         log::debug!("Found session for: {}", session_id);
@@ -1452,24 +1536,8 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
         session_id: &acp::SessionId,
         cx: &mut App,
     ) -> Task<Result<()>> {
-        self.0.update(cx, |agent, cx| {
-            let thread = agent.sessions.get(session_id).map(|s| s.thread.clone());
-            if let Some(thread) = thread {
-                agent.save_thread(thread, cx);
-            }
-
-            let Some(session) = agent.sessions.remove(session_id) else {
-                return Task::ready(Ok(()));
-            };
-            let project_id = session.project_id;
-
-            let has_remaining = agent.sessions.values().any(|s| s.project_id == project_id);
-            if !has_remaining {
-                agent.projects.remove(&project_id);
-            }
-
-            session.pending_save
-        })
+        self.0
+            .update(cx, |agent, cx| agent.close_session(session_id, cx))
     }
 
     fn auth_methods(&self) -> &[acp::AuthMethod] {
@@ -1498,6 +1566,13 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
         log::debug!("Prompt blocks count: {}", params.prompt.len());
 
         let Some(project_state) = self.0.read(cx).session_project_state(&session_id) else {
+            log::error!("Session not found in prompt: {}", session_id);
+            if self.0.read(cx).sessions.contains_key(&session_id) {
+                log::error!(
+                    "Session found in sessions map, but not in project state: {}",
+                    session_id
+                );
+            }
             return Task::ready(Err(anyhow::anyhow!("Session not found")));
         };
 
@@ -1812,7 +1887,7 @@ impl NativeThreadEnvironment {
                     .get(&parent_session_id)
                     .map(|s| s.project_id)
                     .context("parent session not found")?;
-                Ok(agent.register_session(subagent_thread.clone(), project_id, cx))
+                Ok(agent.register_session(subagent_thread.clone(), project_id, 1, cx))
             })??;
 
         let depth = current_depth + 1;
@@ -3006,6 +3081,241 @@ mod internal_tests {
         });
     }
 
+    #[gpui::test]
+    async fn test_thread_summary_releases_loaded_session(cx: &mut TestAppContext) {
+        init_test(cx);
+        let fs = FakeFs::new(cx.executor());
+        fs.insert_tree(
+            "/",
+            json!({
+                "a": {
+                    "file.txt": "hello"
+                }
+            }),
+        )
+        .await;
+        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
+        let thread_store = cx.new(|cx| ThreadStore::new(cx));
+        let agent = cx.update(|cx| {
+            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
+        });
+        let connection = Rc::new(NativeAgentConnection(agent.clone()));
+
+        let acp_thread = cx
+            .update(|cx| {
+                connection
+                    .clone()
+                    .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
+            })
+            .await
+            .unwrap();
+        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
+        let thread = agent.read_with(cx, |agent, _| {
+            agent.sessions.get(&session_id).unwrap().thread.clone()
+        });
+
+        let model = Arc::new(FakeLanguageModel::default());
+        let summary_model = Arc::new(FakeLanguageModel::default());
+        thread.update(cx, |thread, cx| {
+            thread.set_model(model.clone(), cx);
+            thread.set_summarization_model(Some(summary_model.clone()), cx);
+        });
+
+        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx));
+        let send = cx.foreground_executor().spawn(send);
+        cx.run_until_parked();
+
+        model.send_last_completion_stream_text_chunk("world");
+        model.end_last_completion_stream();
+        send.await.unwrap();
+        cx.run_until_parked();
+
+        let summary = agent.update(cx, |agent, cx| {
+            agent.thread_summary(session_id.clone(), project.clone(), cx)
+        });
+        cx.run_until_parked();
+
+        summary_model.send_last_completion_stream_text_chunk("summary");
+        summary_model.end_last_completion_stream();
+
+        assert_eq!(summary.await.unwrap(), "summary");
+        cx.run_until_parked();
+
+        agent.read_with(cx, |agent, _| {
+            let session = agent
+                .sessions
+                .get(&session_id)
+                .expect("thread_summary should not close the active session");
+            assert_eq!(
+                session.ref_count, 1,
+                "thread_summary should release its temporary session reference"
+            );
+        });
+
+        cx.update(|cx| connection.clone().close_session(&session_id, cx))
+            .await
+            .unwrap();
+        cx.run_until_parked();
+
+        agent.read_with(cx, |agent, _| {
+            assert!(
+                agent.sessions.is_empty(),
+                "closing the active session after thread_summary should unload it"
+            );
+        });
+    }
+
+    #[gpui::test]
+    async fn test_loaded_sessions_keep_state_until_last_close(cx: &mut TestAppContext) {
+        init_test(cx);
+        let fs = FakeFs::new(cx.executor());
+        fs.insert_tree(
+            "/",
+            json!({
+                "a": {
+                    "file.txt": "hello"
+                }
+            }),
+        )
+        .await;
+        let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
+        let thread_store = cx.new(|cx| ThreadStore::new(cx));
+        let agent = cx.update(|cx| {
+            NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
+        });
+        let connection = Rc::new(NativeAgentConnection(agent.clone()));
+
+        let acp_thread = cx
+            .update(|cx| {
+                connection
+                    .clone()
+                    .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
+            })
+            .await
+            .unwrap();
+        let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
+        let thread = agent.read_with(cx, |agent, _| {
+            agent.sessions.get(&session_id).unwrap().thread.clone()
+        });
+
+        let model = cx.update(|cx| {
+            LanguageModelRegistry::read_global(cx)
+                .default_model()
+                .map(|default_model| default_model.model)
+                .expect("default test model should be available")
+        });
+        let fake_model = model.as_fake();
+        thread.update(cx, |thread, cx| {
+            thread.set_model(model.clone(), cx);
+        });
+
+        let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx));
+        let send = cx.foreground_executor().spawn(send);
+        cx.run_until_parked();
+
+        fake_model.send_last_completion_stream_text_chunk("world");
+        fake_model.end_last_completion_stream();
+        send.await.unwrap();
+        cx.run_until_parked();
+
+        cx.update(|cx| connection.clone().close_session(&session_id, cx))
+            .await
+            .unwrap();
+        drop(thread);
+        drop(acp_thread);
+        agent.read_with(cx, |agent, _| {
+            assert!(agent.sessions.is_empty());
+        });
+
+        let first_loaded_thread = cx.update(|cx| {
+            connection.clone().load_session(
+                session_id.clone(),
+                project.clone(),
+                PathList::new(&[Path::new("")]),
+                None,
+                cx,
+            )
+        });
+        let second_loaded_thread = cx.update(|cx| {
+            connection.clone().load_session(
+                session_id.clone(),
+                project.clone(),
+                PathList::new(&[Path::new("")]),
+                None,
+                cx,
+            )
+        });
+
+        let first_loaded_thread = first_loaded_thread.await.unwrap();
+        let second_loaded_thread = second_loaded_thread.await.unwrap();
+
+        cx.run_until_parked();
+
+        assert_eq!(
+            first_loaded_thread.entity_id(),
+            second_loaded_thread.entity_id(),
+            "concurrent loads for the same session should share one AcpThread"
+        );
+
+        cx.update(|cx| connection.clone().close_session(&session_id, cx))
+            .await
+            .unwrap();
+
+        agent.read_with(cx, |agent, _| {
+            assert!(
+                agent.sessions.contains_key(&session_id),
+                "closing one loaded session should not drop shared session state"
+            );
+        });
+
+        let follow_up = second_loaded_thread.update(cx, |thread, cx| {
+            thread.send(vec!["still there?".into()], cx)
+        });
+        let follow_up = cx.foreground_executor().spawn(follow_up);
+        cx.run_until_parked();
+
+        fake_model.send_last_completion_stream_text_chunk("yes");
+        fake_model.end_last_completion_stream();
+        follow_up.await.unwrap();
+        cx.run_until_parked();
+
+        second_loaded_thread.read_with(cx, |thread, cx| {
+            assert_eq!(
+                thread.to_markdown(cx),
+                formatdoc! {"
+                    ## User
+
+                    hello
+
+                    ## Assistant
+
+                    world
+
+                    ## User
+
+                    still there?
+
+                    ## Assistant
+
+                    yes
+
+                "}
+            );
+        });
+
+        cx.update(|cx| connection.clone().close_session(&session_id, cx))
+            .await
+            .unwrap();
+
+        cx.run_until_parked();
+
+        drop(first_loaded_thread);
+        drop(second_loaded_thread);
+        agent.read_with(cx, |agent, _| {
+            assert!(agent.sessions.is_empty());
+        });
+    }
+
     #[gpui::test]
     async fn test_rapid_title_changes_do_not_loop(cx: &mut TestAppContext) {
         // Regression test: rapid title changes must not cause a propagation loop