From 604a88f6e35697c1219c27ee92e7d7fa7169741e Mon Sep 17 00:00:00 2001 From: Nathan Sobo Date: Sat, 2 Aug 2025 08:45:51 -0600 Subject: [PATCH] Add comprehensive test for AgentConnection with ModelSelector - Add public session_id() method to AcpThread to enable testing - Fix ModelSelector methods to use async move closures properly to avoid borrow conflicts - Add test_agent_connection that verifies: - Model selector is available for agent2 - Can list available models - Can create threads with default model - Can query selected model for a session - Can send prompts using the selected model - Can cancel sessions - Handles errors for invalid sessions - Remove unnecessary mut keywords from async closures --- crates/acp_thread/src/acp_thread.rs | 4 + crates/agent2/src/agent.rs | 68 +++++++-------- crates/agent2/src/tests/mod.rs | 128 +++++++++++++++++++++++++++- 3 files changed, 163 insertions(+), 37 deletions(-) diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs index d10fecdb286eab0c06fb106062cbb9fa38430f87..c42c155e89023502f86666554d46cd3b7a4819bc 100644 --- a/crates/acp_thread/src/acp_thread.rs +++ b/crates/acp_thread/src/acp_thread.rs @@ -656,6 +656,10 @@ impl AcpThread { &self.entries } + pub fn session_id(&self) -> &acp::SessionId { + &self.session_id + } + pub fn status(&self) -> ThreadStatus { if self.send_task.is_some() { if self.waiting_for_tool_confirmation() { diff --git a/crates/agent2/src/agent.rs b/crates/agent2/src/agent.rs index f10738313ea9bff0ba62c5d9b181a469a3d25143..abd23de37594028ac263d4cb7a0919798441d376 100644 --- a/crates/agent2/src/agent.rs +++ b/crates/agent2/src/agent.rs @@ -33,16 +33,17 @@ pub struct AgentConnection(pub Entity); impl ModelSelector for AgentConnection { fn list_models(&self, cx: &mut AsyncApp) -> Task>>> { - let result = cx.update(|cx| { - let registry = LanguageModelRegistry::read_global(cx); - let models = registry.available_models(cx).collect::>(); - if models.is_empty() { - Err(anyhow::anyhow!("No models available")) - } else { - Ok(models) - } - }); - Task::ready(result.unwrap_or_else(|e| Err(anyhow::anyhow!("Failed to update: {}", e)))) + cx.spawn(async move |cx| { + cx.update(|cx| { + let registry = LanguageModelRegistry::read_global(cx); + let models = registry.available_models(cx).collect::>(); + if models.is_empty() { + Err(anyhow::anyhow!("No models available")) + } else { + Ok(models) + } + })? + }) } fn select_model( @@ -52,17 +53,19 @@ impl ModelSelector for AgentConnection { cx: &mut AsyncApp, ) -> Task> { let agent = self.0.clone(); - let result = agent.update(cx, |agent, cx| { - if let Some(thread) = agent.sessions.get(session_id) { - thread.update(cx, |thread, _| { - thread.selected_model = model; - }); - Ok(()) - } else { - Err(anyhow::anyhow!("Session not found")) - } - }); - Task::ready(result.unwrap_or_else(|e| Err(anyhow::anyhow!("Failed to update: {}", e)))) + let session_id = session_id.clone(); + cx.spawn(async move |cx| { + agent.update(cx, |agent, cx| { + if let Some(thread) = agent.sessions.get(&session_id) { + thread.update(cx, |thread, _| { + thread.selected_model = model; + }); + Ok(()) + } else { + Err(anyhow::anyhow!("Session not found")) + } + })? + }) } fn selected_model( @@ -71,21 +74,14 @@ impl ModelSelector for AgentConnection { cx: &mut AsyncApp, ) -> Task>> { let agent = self.0.clone(); - let thread_result = agent - .read_with(cx, |agent, _| agent.sessions.get(session_id).cloned()) - .ok() - .flatten() - .ok_or_else(|| anyhow::anyhow!("Session not found")); - - match thread_result { - Ok(thread) => { - let selected = thread - .read_with(cx, |thread, _| thread.selected_model.clone()) - .unwrap_or_else(|e| panic!("Failed to read thread: {}", e)); - Task::ready(Ok(selected)) - } - Err(e) => Task::ready(Err(e)), - } + let session_id = session_id.clone(); + cx.spawn(async move |cx| { + let thread = agent + .read_with(cx, |agent, _| agent.sessions.get(&session_id).cloned())? + .ok_or_else(|| anyhow::anyhow!("Session not found"))?; + let selected = thread.read_with(cx, |thread, _| thread.selected_model.clone())?; + Ok(selected) + }) } } diff --git a/crates/agent2/src/tests/mod.rs b/crates/agent2/src/tests/mod.rs index f7dc9055f69afcef857f868a0cf0e1a824503424..a628658b221e132954fafa90fe0c458b37fcea45 100644 --- a/crates/agent2/src/tests/mod.rs +++ b/crates/agent2/src/tests/mod.rs @@ -1,16 +1,19 @@ use super::*; use crate::templates::Templates; +use acp_thread::AgentConnection as _; +use agent_client_protocol as acp; use client::{Client, UserStore}; use gpui::{AppContext, Entity, TestAppContext}; use language_model::{ LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelRegistry, MessageContent, StopReason, }; +use project::Project; use reqwest_client::ReqwestClient; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use smol::stream::StreamExt; -use std::{sync::Arc, time::Duration}; +use std::{path::Path, rc::Rc, sync::Arc, time::Duration}; mod test_tools; use test_tools::*; @@ -187,6 +190,129 @@ async fn test_concurrent_tool_calls(cx: &mut TestAppContext) { }); } +#[gpui::test] +async fn test_agent_connection(cx: &mut TestAppContext) { + cx.executor().allow_parking(); + cx.update(settings::init); + let templates = Templates::new(); + + // Initialize language model system with test provider + cx.update(|cx| { + gpui_tokio::init(cx); + let http_client = ReqwestClient::user_agent("agent tests").unwrap(); + cx.set_http_client(Arc::new(http_client)); + + client::init_settings(cx); + let client = Client::production(cx); + let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); + language_model::init(client.clone(), cx); + language_models::init(user_store.clone(), client.clone(), cx); + + // Initialize project settings + Project::init_settings(cx); + + // Use test registry with fake provider + LanguageModelRegistry::test(cx); + }); + + // Create agent and connection + let agent = cx.new(|_| Agent::new(templates.clone())); + let connection = AgentConnection(agent.clone()); + + // Test model_selector returns Some + let selector_opt = connection.model_selector(); + assert!( + selector_opt.is_some(), + "agent2 should always support ModelSelector" + ); + let selector = selector_opt.unwrap(); + + // Test list_models + let listed_models = cx + .update(|cx| { + let mut async_cx = cx.to_async(); + selector.list_models(&mut async_cx) + }) + .await + .expect("list_models should succeed"); + assert!(!listed_models.is_empty(), "should have at least one model"); + assert_eq!(listed_models[0].id().0, "fake"); + + // Create a project for new_thread + let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone())); + let project = Project::test(fake_fs, [Path::new("/test")], cx).await; + + // Create a thread using new_thread + let cwd = Path::new("/test"); + let connection_rc = Rc::new(connection.clone()); + let acp_thread = cx + .update(|cx| { + let mut async_cx = cx.to_async(); + connection_rc.new_thread(project, cwd, &mut async_cx) + }) + .await + .expect("new_thread should succeed"); + + // Get the session_id from the AcpThread + let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone()); + + // Test selected_model returns the default + let selected = cx + .update(|cx| { + let mut async_cx = cx.to_async(); + selector.selected_model(&session_id, &mut async_cx) + }) + .await + .expect("selected_model should succeed"); + assert_eq!(selected.id().0, "fake", "should return default model"); + + // The thread was created via prompt with the default model + // We can verify it through selected_model + + // Test prompt uses the selected model + let prompt_request = acp::PromptRequest { + session_id: session_id.clone(), + prompt: vec![acp::ContentBlock::Text(acp::TextContent { + text: "Test prompt".into(), + annotations: None, + })], + }; + + cx.update(|cx| connection.prompt(prompt_request, cx)) + .await + .expect("prompt should succeed"); + + // The prompt was sent successfully + + // Test cancel + cx.update(|cx| connection.cancel(&session_id, cx)); + + // After cancel, selected_model should fail + let result = cx + .update(|cx| { + let mut async_cx = cx.to_async(); + selector.selected_model(&session_id, &mut async_cx) + }) + .await; + assert!(result.is_err(), "selected_model should fail after cancel"); + + // Test error case: invalid session + let invalid_session = acp::SessionId("invalid".into()); + let result = cx + .update(|cx| { + let mut async_cx = cx.to_async(); + selector.selected_model(&invalid_session, &mut async_cx) + }) + .await; + assert!(result.is_err(), "should fail for invalid session"); + if let Err(e) = result { + assert!( + e.to_string().contains("Session not found"), + "should have correct error message" + ); + } +} + /// Filters out the stop events for asserting against in tests fn stop_events( result_events: Vec>,