Add comprehensive test for AgentConnection with ModelSelector

Nathan Sobo created

- Add public session_id() method to AcpThread to enable testing
- Fix ModelSelector methods to use async move closures properly to avoid borrow conflicts
- Add test_agent_connection that verifies:
  - Model selector is available for agent2
  - Can list available models
  - Can create threads with default model
  - Can query selected model for a session
  - Can send prompts using the selected model
  - Can cancel sessions
  - Handles errors for invalid sessions
- Remove unnecessary mut keywords from async closures

Change summary

crates/acp_thread/src/acp_thread.rs |   4 
crates/agent2/src/agent.rs          |  68 +++++++--------
crates/agent2/src/tests/mod.rs      | 128 ++++++++++++++++++++++++++++++
3 files changed, 163 insertions(+), 37 deletions(-)

Detailed changes

crates/acp_thread/src/acp_thread.rs 🔗

@@ -656,6 +656,10 @@ impl AcpThread {
         &self.entries
     }
 
+    pub fn session_id(&self) -> &acp::SessionId {
+        &self.session_id
+    }
+
     pub fn status(&self) -> ThreadStatus {
         if self.send_task.is_some() {
             if self.waiting_for_tool_confirmation() {

crates/agent2/src/agent.rs 🔗

@@ -33,16 +33,17 @@ pub struct AgentConnection(pub Entity<Agent>);
 
 impl ModelSelector for AgentConnection {
     fn list_models(&self, cx: &mut AsyncApp) -> Task<Result<Vec<Arc<dyn LanguageModel>>>> {
-        let result = cx.update(|cx| {
-            let registry = LanguageModelRegistry::read_global(cx);
-            let models = registry.available_models(cx).collect::<Vec<_>>();
-            if models.is_empty() {
-                Err(anyhow::anyhow!("No models available"))
-            } else {
-                Ok(models)
-            }
-        });
-        Task::ready(result.unwrap_or_else(|e| Err(anyhow::anyhow!("Failed to update: {}", e))))
+        cx.spawn(async move |cx| {
+            cx.update(|cx| {
+                let registry = LanguageModelRegistry::read_global(cx);
+                let models = registry.available_models(cx).collect::<Vec<_>>();
+                if models.is_empty() {
+                    Err(anyhow::anyhow!("No models available"))
+                } else {
+                    Ok(models)
+                }
+            })?
+        })
     }
 
     fn select_model(
@@ -52,17 +53,19 @@ impl ModelSelector for AgentConnection {
         cx: &mut AsyncApp,
     ) -> Task<Result<()>> {
         let agent = self.0.clone();
-        let result = agent.update(cx, |agent, cx| {
-            if let Some(thread) = agent.sessions.get(session_id) {
-                thread.update(cx, |thread, _| {
-                    thread.selected_model = model;
-                });
-                Ok(())
-            } else {
-                Err(anyhow::anyhow!("Session not found"))
-            }
-        });
-        Task::ready(result.unwrap_or_else(|e| Err(anyhow::anyhow!("Failed to update: {}", e))))
+        let session_id = session_id.clone();
+        cx.spawn(async move |cx| {
+            agent.update(cx, |agent, cx| {
+                if let Some(thread) = agent.sessions.get(&session_id) {
+                    thread.update(cx, |thread, _| {
+                        thread.selected_model = model;
+                    });
+                    Ok(())
+                } else {
+                    Err(anyhow::anyhow!("Session not found"))
+                }
+            })?
+        })
     }
 
     fn selected_model(
@@ -71,21 +74,14 @@ impl ModelSelector for AgentConnection {
         cx: &mut AsyncApp,
     ) -> Task<Result<Arc<dyn LanguageModel>>> {
         let agent = self.0.clone();
-        let thread_result = agent
-            .read_with(cx, |agent, _| agent.sessions.get(session_id).cloned())
-            .ok()
-            .flatten()
-            .ok_or_else(|| anyhow::anyhow!("Session not found"));
-
-        match thread_result {
-            Ok(thread) => {
-                let selected = thread
-                    .read_with(cx, |thread, _| thread.selected_model.clone())
-                    .unwrap_or_else(|e| panic!("Failed to read thread: {}", e));
-                Task::ready(Ok(selected))
-            }
-            Err(e) => Task::ready(Err(e)),
-        }
+        let session_id = session_id.clone();
+        cx.spawn(async move |cx| {
+            let thread = agent
+                .read_with(cx, |agent, _| agent.sessions.get(&session_id).cloned())?
+                .ok_or_else(|| anyhow::anyhow!("Session not found"))?;
+            let selected = thread.read_with(cx, |thread, _| thread.selected_model.clone())?;
+            Ok(selected)
+        })
     }
 }
 

crates/agent2/src/tests/mod.rs 🔗

@@ -1,16 +1,19 @@
 use super::*;
 use crate::templates::Templates;
+use acp_thread::AgentConnection as _;
+use agent_client_protocol as acp;
 use client::{Client, UserStore};
 use gpui::{AppContext, Entity, TestAppContext};
 use language_model::{
     LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
     LanguageModelRegistry, MessageContent, StopReason,
 };
+use project::Project;
 use reqwest_client::ReqwestClient;
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use smol::stream::StreamExt;
-use std::{sync::Arc, time::Duration};
+use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
 
 mod test_tools;
 use test_tools::*;
@@ -187,6 +190,129 @@ async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
     });
 }
 
+#[gpui::test]
+async fn test_agent_connection(cx: &mut TestAppContext) {
+    cx.executor().allow_parking();
+    cx.update(settings::init);
+    let templates = Templates::new();
+
+    // Initialize language model system with test provider
+    cx.update(|cx| {
+        gpui_tokio::init(cx);
+        let http_client = ReqwestClient::user_agent("agent tests").unwrap();
+        cx.set_http_client(Arc::new(http_client));
+
+        client::init_settings(cx);
+        let client = Client::production(cx);
+        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
+        language_model::init(client.clone(), cx);
+        language_models::init(user_store.clone(), client.clone(), cx);
+
+        // Initialize project settings
+        Project::init_settings(cx);
+
+        // Use test registry with fake provider
+        LanguageModelRegistry::test(cx);
+    });
+
+    // Create agent and connection
+    let agent = cx.new(|_| Agent::new(templates.clone()));
+    let connection = AgentConnection(agent.clone());
+
+    // Test model_selector returns Some
+    let selector_opt = connection.model_selector();
+    assert!(
+        selector_opt.is_some(),
+        "agent2 should always support ModelSelector"
+    );
+    let selector = selector_opt.unwrap();
+
+    // Test list_models
+    let listed_models = cx
+        .update(|cx| {
+            let mut async_cx = cx.to_async();
+            selector.list_models(&mut async_cx)
+        })
+        .await
+        .expect("list_models should succeed");
+    assert!(!listed_models.is_empty(), "should have at least one model");
+    assert_eq!(listed_models[0].id().0, "fake");
+
+    // Create a project for new_thread
+    let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
+    let project = Project::test(fake_fs, [Path::new("/test")], cx).await;
+
+    // Create a thread using new_thread
+    let cwd = Path::new("/test");
+    let connection_rc = Rc::new(connection.clone());
+    let acp_thread = cx
+        .update(|cx| {
+            let mut async_cx = cx.to_async();
+            connection_rc.new_thread(project, cwd, &mut async_cx)
+        })
+        .await
+        .expect("new_thread should succeed");
+
+    // Get the session_id from the AcpThread
+    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
+
+    // Test selected_model returns the default
+    let selected = cx
+        .update(|cx| {
+            let mut async_cx = cx.to_async();
+            selector.selected_model(&session_id, &mut async_cx)
+        })
+        .await
+        .expect("selected_model should succeed");
+    assert_eq!(selected.id().0, "fake", "should return default model");
+
+    // The thread was created via prompt with the default model
+    // We can verify it through selected_model
+
+    // Test prompt uses the selected model
+    let prompt_request = acp::PromptRequest {
+        session_id: session_id.clone(),
+        prompt: vec![acp::ContentBlock::Text(acp::TextContent {
+            text: "Test prompt".into(),
+            annotations: None,
+        })],
+    };
+
+    cx.update(|cx| connection.prompt(prompt_request, cx))
+        .await
+        .expect("prompt should succeed");
+
+    // The prompt was sent successfully
+
+    // Test cancel
+    cx.update(|cx| connection.cancel(&session_id, cx));
+
+    // After cancel, selected_model should fail
+    let result = cx
+        .update(|cx| {
+            let mut async_cx = cx.to_async();
+            selector.selected_model(&session_id, &mut async_cx)
+        })
+        .await;
+    assert!(result.is_err(), "selected_model should fail after cancel");
+
+    // Test error case: invalid session
+    let invalid_session = acp::SessionId("invalid".into());
+    let result = cx
+        .update(|cx| {
+            let mut async_cx = cx.to_async();
+            selector.selected_model(&invalid_session, &mut async_cx)
+        })
+        .await;
+    assert!(result.is_err(), "should fail for invalid session");
+    if let Err(e) = result {
+        assert!(
+            e.to_string().contains("Session not found"),
+            "should have correct error message"
+        );
+    }
+}
+
 /// Filters out the stop events for asserting against in tests
 fn stop_events(
     result_events: Vec<Result<AgentResponseEvent, LanguageModelCompletionError>>,