Drop native agent session when `AcpThread` gets released (#35713)

Antonio Scandurra and Ben Brandt created

Release Notes:

- N/A

Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>

Change summary

Cargo.lock                     |  1 
crates/agent2/Cargo.toml       |  1 
crates/agent2/src/agent.rs     | 12 +++-
crates/agent2/src/tests/mod.rs | 75 +++++++++++++++++++++++------------
4 files changed, 58 insertions(+), 31 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -160,6 +160,7 @@ dependencies = [
  "agent_servers",
  "anyhow",
  "client",
+ "clock",
  "cloud_llm_client",
  "collections",
  "ctor",

crates/agent2/Cargo.toml 🔗

@@ -42,6 +42,7 @@ workspace-hack.workspace = true
 [dev-dependencies]
 ctor.workspace = true
 client = { workspace = true, "features" = ["test-support"] }
+clock = { workspace = true, "features" = ["test-support"] }
 env_logger.workspace = true
 fs = { workspace = true, "features" = ["test-support"] }
 gpui = { workspace = true, "features" = ["test-support"] }

crates/agent2/src/agent.rs 🔗

@@ -2,7 +2,7 @@ use acp_thread::ModelSelector;
 use agent_client_protocol as acp;
 use anyhow::{anyhow, Result};
 use futures::StreamExt;
-use gpui::{App, AppContext, AsyncApp, Entity, Task};
+use gpui::{App, AppContext, AsyncApp, Entity, Subscription, Task, WeakEntity};
 use language_model::{LanguageModel, LanguageModelRegistry};
 use project::Project;
 use std::collections::HashMap;
@@ -17,7 +17,8 @@ struct Session {
     /// The internal thread that processes messages
     thread: Entity<Thread>,
     /// The ACP thread that handles protocol communication
-    acp_thread: Entity<acp_thread::AcpThread>,
+    acp_thread: WeakEntity<acp_thread::AcpThread>,
+    _subscription: Subscription,
 }
 
 pub struct NativeAgent {
@@ -162,12 +163,15 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
             })?;
 
             // Store the session
-            agent.update(cx, |agent, _cx| {
+            agent.update(cx, |agent, cx| {
                 agent.sessions.insert(
                     session_id,
                     Session {
                         thread,
-                        acp_thread: acp_thread.clone(),
+                        acp_thread: acp_thread.downgrade(),
+                        _subscription: cx.observe_release(&acp_thread, |this, acp_thread, _cx| {
+                            this.sessions.remove(acp_thread.session_id());
+                        })
                     },
                 );
             })?;

crates/agent2/src/tests/mod.rs 🔗

@@ -1,10 +1,10 @@
 use super::*;
 use crate::templates::Templates;
-use acp_thread::AgentConnection as _;
+use acp_thread::AgentConnection;
 use agent_client_protocol as acp;
 use client::{Client, UserStore};
 use fs::FakeFs;
-use gpui::{AppContext, Entity, Task, TestAppContext};
+use gpui::{http_client::FakeHttpClient, AppContext, Entity, Task, TestAppContext};
 use indoc::indoc;
 use language_model::{
     fake_provider::FakeLanguageModel, LanguageModel, LanguageModelCompletionError,
@@ -322,31 +322,26 @@ async fn test_refusal(cx: &mut TestAppContext) {
     });
 }
 
-#[ignore = "temporarily disabled until it can be run on CI"]
 #[gpui::test]
 async fn test_agent_connection(cx: &mut TestAppContext) {
-    cx.executor().allow_parking();
     cx.update(settings::init);
     let templates = Templates::new();
 
     // Initialize language model system with test provider
     cx.update(|cx| {
         gpui_tokio::init(cx);
-        let http_client = ReqwestClient::user_agent("agent tests").unwrap();
-        cx.set_http_client(Arc::new(http_client));
-
         client::init_settings(cx);
-        let client = Client::production(cx);
+
+        let http_client = FakeHttpClient::with_404_response();
+        let clock = Arc::new(clock::FakeSystemClock::new());
+        let client = Client::new(clock, http_client, cx);
         let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
         language_model::init(client.clone(), cx);
         language_models::init(user_store.clone(), client.clone(), cx);
-
-        // Initialize project settings
         Project::init_settings(cx);
-
-        // Use test registry with fake provider
         LanguageModelRegistry::test(cx);
     });
+    cx.executor().forbid_parking();
 
     // Create agent and connection
     let agent = cx.new(|_| NativeAgent::new(templates.clone()));
@@ -390,34 +385,60 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
     let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
 
     // Test selected_model returns the default
-    let selected = cx
+    let model = cx
         .update(|cx| {
             let mut async_cx = cx.to_async();
             selector.selected_model(&session_id, &mut async_cx)
         })
         .await
         .expect("selected_model should succeed");
-    assert_eq!(selected.id().0, "fake", "should return default model");
+    let model = model.as_fake();
+    assert_eq!(model.id().0, "fake", "should return default model");
+
+    let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
+    cx.run_until_parked();
+    model.send_last_completion_stream_text_chunk("def");
+    cx.run_until_parked();
+    acp_thread.read_with(cx, |thread, cx| {
+        assert_eq!(
+            thread.to_markdown(cx),
+            indoc! {"
+                ## User
 
-    // The thread was created via prompt with the default model
-    // We can verify it through selected_model
+                abc
 
-    // Test prompt uses the selected model
-    let prompt_request = acp::PromptRequest {
-        session_id: session_id.clone(),
-        prompt: vec![acp::ContentBlock::Text(acp::TextContent {
-            text: "Test prompt".into(),
-            annotations: None,
-        })],
-    };
+                ## Assistant
 
-    let request = cx.update(|cx| connection.prompt(prompt_request, cx));
-    let request = cx.background_spawn(request);
-    smol::Timer::after(Duration::from_millis(100)).await;
+                def
+
+            "}
+        )
+    });
 
     // Test cancel
     cx.update(|cx| connection.cancel(&session_id, cx));
     request.await.expect("prompt should fail gracefully");
+
+    // Ensure that dropping the ACP thread causes the native thread to be
+    // dropped as well.
+    cx.update(|_| drop(acp_thread));
+    let result = cx
+        .update(|cx| {
+            connection.prompt(
+                acp::PromptRequest {
+                    session_id: session_id.clone(),
+                    prompt: vec!["ghi".into()],
+                },
+                cx,
+            )
+        })
+        .await;
+    assert_eq!(
+        result.as_ref().unwrap_err().to_string(),
+        "Session not found",
+        "unexpected result: {:?}",
+        result
+    );
 }
 
 /// Filters out the stop events for asserting against in tests