@@ -84,6 +84,12 @@ struct Session {
project_id: EntityId,
pending_save: Task<Result<()>>,
_subscriptions: Vec<Subscription>,
+ ref_count: usize,
+}
+
+struct PendingSession {
+ task: Shared<Task<Result<Entity<AcpThread>, Arc<anyhow::Error>>>>,
+ ref_count: usize,
}
pub struct LanguageModels {
@@ -245,6 +251,7 @@ impl LanguageModels {
pub struct NativeAgent {
/// Session ID -> Session mapping
sessions: HashMap<acp::SessionId, Session>,
+ pending_sessions: HashMap<acp::SessionId, PendingSession>,
thread_store: Entity<ThreadStore>,
/// Project-specific state keyed by project EntityId
projects: HashMap<EntityId, ProjectState>,
@@ -278,6 +285,7 @@ impl NativeAgent {
Self {
sessions: HashMap::default(),
+ pending_sessions: HashMap::default(),
thread_store,
projects: HashMap::default(),
templates,
@@ -316,13 +324,14 @@ impl NativeAgent {
)
});
- self.register_session(thread, project_id, cx)
+ self.register_session(thread, project_id, 1, cx)
}
fn register_session(
&mut self,
thread_handle: Entity<Thread>,
project_id: EntityId,
+ ref_count: usize,
cx: &mut Context<Self>,
) -> Entity<AcpThread> {
let connection = Rc::new(NativeAgentConnection(cx.entity()));
@@ -388,6 +397,7 @@ impl NativeAgent {
project_id,
_subscriptions: subscriptions,
pending_save: Task::ready(Ok(())),
+ ref_count,
},
);
@@ -926,27 +936,68 @@ impl NativeAgent {
project: Entity<Project>,
cx: &mut Context<Self>,
) -> Task<Result<Entity<AcpThread>>> {
- if let Some(session) = self.sessions.get(&id) {
+ if let Some(session) = self.sessions.get_mut(&id) {
+ session.ref_count += 1;
return Task::ready(Ok(session.acp_thread.clone()));
}
- let task = self.load_thread(id, project.clone(), cx);
- cx.spawn(async move |this, cx| {
- let thread = task.await?;
- let acp_thread = this.update(cx, |this, cx| {
- let project_id = this.get_or_create_project_state(&project, cx);
- this.register_session(thread.clone(), project_id, cx)
- })?;
- let events = thread.update(cx, |thread, cx| thread.replay(cx));
- cx.update(|cx| {
- NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx)
+ if let Some(pending) = self.pending_sessions.get_mut(&id) {
+ pending.ref_count += 1;
+ let task = pending.task.clone();
+ return cx.spawn(async move |_, _cx| task.await.map_err(|err| anyhow!(err)));
+ }
+
+ let task = self.load_thread(id.clone(), project.clone(), cx);
+ let shared_task = cx
+ .spawn({
+ let id = id.clone();
+ async move |this, cx| {
+ let thread = match task.await {
+ Ok(thread) => thread,
+ Err(err) => {
+ this.update(cx, |this, _cx| {
+ this.pending_sessions.remove(&id);
+ })
+ .ok();
+ return Err(Arc::new(err));
+ }
+ };
+ let acp_thread = this
+ .update(cx, |this, cx| {
+ let project_id = this.get_or_create_project_state(&project, cx);
+ let ref_count = this
+ .pending_sessions
+ .remove(&id)
+ .map_or(1, |pending| pending.ref_count);
+ this.register_session(thread.clone(), project_id, ref_count, cx)
+ })
+ .map_err(Arc::new)?;
+ let events = thread.update(cx, |thread, cx| thread.replay(cx));
+ cx.update(|cx| {
+ NativeAgentConnection::handle_thread_events(
+ events,
+ acp_thread.downgrade(),
+ cx,
+ )
+ })
+ .await
+ .map_err(Arc::new)?;
+ acp_thread.update(cx, |thread, cx| {
+ thread.snapshot_completed_plan(cx);
+ });
+ Ok(acp_thread)
+ }
})
- .await?;
- acp_thread.update(cx, |thread, cx| {
- thread.snapshot_completed_plan(cx);
- });
- Ok(acp_thread)
- })
+ .shared();
+ self.pending_sessions.insert(
+ id,
+ PendingSession {
+ task: shared_task.clone(),
+ ref_count: 1,
+ },
+ );
+
+ cx.background_spawn(async move { shared_task.await.map_err(|err| anyhow!(err)) })
}
pub fn thread_summary(
@@ -968,11 +1019,43 @@ impl NativeAgent {
})?
.await
.context("Failed to generate summary")?;
+
+ this.update(cx, |this, cx| this.close_session(&id, cx))?
+ .await?;
drop(acp_thread);
Ok(result)
})
}
+ fn close_session(
+ &mut self,
+ session_id: &acp::SessionId,
+ cx: &mut Context<Self>,
+ ) -> Task<Result<()>> {
+ let Some(session) = self.sessions.get_mut(session_id) else {
+ return Task::ready(Ok(()));
+ };
+
+ session.ref_count -= 1;
+ if session.ref_count > 0 {
+ return Task::ready(Ok(()));
+ }
+
+ let thread = session.thread.clone();
+ self.save_thread(thread, cx);
+ let Some(session) = self.sessions.remove(session_id) else {
+ return Task::ready(Ok(()));
+ };
+ let project_id = session.project_id;
+
+ let has_remaining = self.sessions.values().any(|s| s.project_id == project_id);
+ if !has_remaining {
+ self.projects.remove(&project_id);
+ }
+
+ session.pending_save
+ }
+
fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
if thread.read(cx).is_empty() {
return;
@@ -1158,6 +1241,7 @@ impl NativeAgentConnection {
.get_mut(&session_id)
.map(|s| (s.thread.clone(), s.acp_thread.clone()))
}) else {
+ log::error!("Session not found in run_turn: {}", session_id);
return Task::ready(Err(anyhow!("Session not found")));
};
log::debug!("Found session for: {}", session_id);
@@ -1452,24 +1536,8 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
session_id: &acp::SessionId,
cx: &mut App,
) -> Task<Result<()>> {
- self.0.update(cx, |agent, cx| {
- let thread = agent.sessions.get(session_id).map(|s| s.thread.clone());
- if let Some(thread) = thread {
- agent.save_thread(thread, cx);
- }
-
- let Some(session) = agent.sessions.remove(session_id) else {
- return Task::ready(Ok(()));
- };
- let project_id = session.project_id;
-
- let has_remaining = agent.sessions.values().any(|s| s.project_id == project_id);
- if !has_remaining {
- agent.projects.remove(&project_id);
- }
-
- session.pending_save
- })
+ self.0
+ .update(cx, |agent, cx| agent.close_session(session_id, cx))
}
fn auth_methods(&self) -> &[acp::AuthMethod] {
@@ -1498,6 +1566,13 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
log::debug!("Prompt blocks count: {}", params.prompt.len());
let Some(project_state) = self.0.read(cx).session_project_state(&session_id) else {
+ log::error!("Session not found in prompt: {}", session_id);
+ if self.0.read(cx).sessions.contains_key(&session_id) {
+ log::error!(
+ "Session found in sessions map, but not in project state: {}",
+ session_id
+ );
+ }
return Task::ready(Err(anyhow::anyhow!("Session not found")));
};
@@ -1812,7 +1887,7 @@ impl NativeThreadEnvironment {
.get(&parent_session_id)
.map(|s| s.project_id)
.context("parent session not found")?;
- Ok(agent.register_session(subagent_thread.clone(), project_id, cx))
+ Ok(agent.register_session(subagent_thread.clone(), project_id, 1, cx))
})??;
let depth = current_depth + 1;
@@ -3006,6 +3081,241 @@ mod internal_tests {
});
}
+ #[gpui::test]
+ async fn test_thread_summary_releases_loaded_session(cx: &mut TestAppContext) {
+ init_test(cx);
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ "/",
+ json!({
+ "a": {
+ "file.txt": "hello"
+ }
+ }),
+ )
+ .await;
+ let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
+ let thread_store = cx.new(|cx| ThreadStore::new(cx));
+ let agent = cx.update(|cx| {
+ NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
+ });
+ let connection = Rc::new(NativeAgentConnection(agent.clone()));
+
+ let acp_thread = cx
+ .update(|cx| {
+ connection
+ .clone()
+ .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
+ })
+ .await
+ .unwrap();
+ let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
+ let thread = agent.read_with(cx, |agent, _| {
+ agent.sessions.get(&session_id).unwrap().thread.clone()
+ });
+
+ let model = Arc::new(FakeLanguageModel::default());
+ let summary_model = Arc::new(FakeLanguageModel::default());
+ thread.update(cx, |thread, cx| {
+ thread.set_model(model.clone(), cx);
+ thread.set_summarization_model(Some(summary_model.clone()), cx);
+ });
+
+ let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx));
+ let send = cx.foreground_executor().spawn(send);
+ cx.run_until_parked();
+
+ model.send_last_completion_stream_text_chunk("world");
+ model.end_last_completion_stream();
+ send.await.unwrap();
+ cx.run_until_parked();
+
+ let summary = agent.update(cx, |agent, cx| {
+ agent.thread_summary(session_id.clone(), project.clone(), cx)
+ });
+ cx.run_until_parked();
+
+ summary_model.send_last_completion_stream_text_chunk("summary");
+ summary_model.end_last_completion_stream();
+
+ assert_eq!(summary.await.unwrap(), "summary");
+ cx.run_until_parked();
+
+ agent.read_with(cx, |agent, _| {
+ let session = agent
+ .sessions
+ .get(&session_id)
+ .expect("thread_summary should not close the active session");
+ assert_eq!(
+ session.ref_count, 1,
+ "thread_summary should release its temporary session reference"
+ );
+ });
+
+ cx.update(|cx| connection.clone().close_session(&session_id, cx))
+ .await
+ .unwrap();
+ cx.run_until_parked();
+
+ agent.read_with(cx, |agent, _| {
+ assert!(
+ agent.sessions.is_empty(),
+ "closing the active session after thread_summary should unload it"
+ );
+ });
+ }
+
+ #[gpui::test]
+ async fn test_loaded_sessions_keep_state_until_last_close(cx: &mut TestAppContext) {
+ init_test(cx);
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ "/",
+ json!({
+ "a": {
+ "file.txt": "hello"
+ }
+ }),
+ )
+ .await;
+ let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
+ let thread_store = cx.new(|cx| ThreadStore::new(cx));
+ let agent = cx.update(|cx| {
+ NativeAgent::new(thread_store.clone(), Templates::new(), None, fs.clone(), cx)
+ });
+ let connection = Rc::new(NativeAgentConnection(agent.clone()));
+
+ let acp_thread = cx
+ .update(|cx| {
+ connection
+ .clone()
+ .new_session(project.clone(), PathList::new(&[Path::new("")]), cx)
+ })
+ .await
+ .unwrap();
+ let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
+ let thread = agent.read_with(cx, |agent, _| {
+ agent.sessions.get(&session_id).unwrap().thread.clone()
+ });
+
+ let model = cx.update(|cx| {
+ LanguageModelRegistry::read_global(cx)
+ .default_model()
+ .map(|default_model| default_model.model)
+ .expect("default test model should be available")
+ });
+ let fake_model = model.as_fake();
+ thread.update(cx, |thread, cx| {
+ thread.set_model(model.clone(), cx);
+ });
+
+ let send = acp_thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx));
+ let send = cx.foreground_executor().spawn(send);
+ cx.run_until_parked();
+
+ fake_model.send_last_completion_stream_text_chunk("world");
+ fake_model.end_last_completion_stream();
+ send.await.unwrap();
+ cx.run_until_parked();
+
+ cx.update(|cx| connection.clone().close_session(&session_id, cx))
+ .await
+ .unwrap();
+ drop(thread);
+ drop(acp_thread);
+ agent.read_with(cx, |agent, _| {
+ assert!(agent.sessions.is_empty());
+ });
+
+ let first_loaded_thread = cx.update(|cx| {
+ connection.clone().load_session(
+ session_id.clone(),
+ project.clone(),
+ PathList::new(&[Path::new("")]),
+ None,
+ cx,
+ )
+ });
+ let second_loaded_thread = cx.update(|cx| {
+ connection.clone().load_session(
+ session_id.clone(),
+ project.clone(),
+ PathList::new(&[Path::new("")]),
+ None,
+ cx,
+ )
+ });
+
+ let first_loaded_thread = first_loaded_thread.await.unwrap();
+ let second_loaded_thread = second_loaded_thread.await.unwrap();
+
+ cx.run_until_parked();
+
+ assert_eq!(
+ first_loaded_thread.entity_id(),
+ second_loaded_thread.entity_id(),
+ "concurrent loads for the same session should share one AcpThread"
+ );
+
+ cx.update(|cx| connection.clone().close_session(&session_id, cx))
+ .await
+ .unwrap();
+
+ agent.read_with(cx, |agent, _| {
+ assert!(
+ agent.sessions.contains_key(&session_id),
+ "closing one loaded session should not drop shared session state"
+ );
+ });
+
+ let follow_up = second_loaded_thread.update(cx, |thread, cx| {
+ thread.send(vec!["still there?".into()], cx)
+ });
+ let follow_up = cx.foreground_executor().spawn(follow_up);
+ cx.run_until_parked();
+
+ fake_model.send_last_completion_stream_text_chunk("yes");
+ fake_model.end_last_completion_stream();
+ follow_up.await.unwrap();
+ cx.run_until_parked();
+
+ second_loaded_thread.read_with(cx, |thread, cx| {
+ assert_eq!(
+ thread.to_markdown(cx),
+ formatdoc! {"
+ ## User
+
+ hello
+
+ ## Assistant
+
+ world
+
+ ## User
+
+ still there?
+
+ ## Assistant
+
+ yes
+
+ "}
+ );
+ });
+
+ cx.update(|cx| connection.clone().close_session(&session_id, cx))
+ .await
+ .unwrap();
+
+ cx.run_until_parked();
+
+ drop(first_loaded_thread);
+ drop(second_loaded_thread);
+ agent.read_with(cx, |agent, _| {
+ assert!(agent.sessions.is_empty());
+ });
+ }
+
#[gpui::test]
async fn test_rapid_title_changes_do_not_loop(cx: &mut TestAppContext) {
// Regression test: rapid title changes must not cause a propagation loop