@@ -33,16 +33,17 @@ pub struct AgentConnection(pub Entity<Agent>);
impl ModelSelector for AgentConnection {
fn list_models(&self, cx: &mut AsyncApp) -> Task<Result<Vec<Arc<dyn LanguageModel>>>> {
- let result = cx.update(|cx| {
- let registry = LanguageModelRegistry::read_global(cx);
- let models = registry.available_models(cx).collect::<Vec<_>>();
- if models.is_empty() {
- Err(anyhow::anyhow!("No models available"))
- } else {
- Ok(models)
- }
- });
- Task::ready(result.unwrap_or_else(|e| Err(anyhow::anyhow!("Failed to update: {}", e))))
+ cx.spawn(async move |cx| {
+ cx.update(|cx| {
+ let registry = LanguageModelRegistry::read_global(cx);
+ let models = registry.available_models(cx).collect::<Vec<_>>();
+ if models.is_empty() {
+ Err(anyhow::anyhow!("No models available"))
+ } else {
+ Ok(models)
+ }
+ })?
+ })
}
fn select_model(
@@ -52,17 +53,19 @@ impl ModelSelector for AgentConnection {
cx: &mut AsyncApp,
) -> Task<Result<()>> {
let agent = self.0.clone();
- let result = agent.update(cx, |agent, cx| {
- if let Some(thread) = agent.sessions.get(session_id) {
- thread.update(cx, |thread, _| {
- thread.selected_model = model;
- });
- Ok(())
- } else {
- Err(anyhow::anyhow!("Session not found"))
- }
- });
- Task::ready(result.unwrap_or_else(|e| Err(anyhow::anyhow!("Failed to update: {}", e))))
+ let session_id = session_id.clone();
+ cx.spawn(async move |cx| {
+ agent.update(cx, |agent, cx| {
+ if let Some(thread) = agent.sessions.get(&session_id) {
+ thread.update(cx, |thread, _| {
+ thread.selected_model = model;
+ });
+ Ok(())
+ } else {
+ Err(anyhow::anyhow!("Session not found"))
+ }
+ })?
+ })
}
fn selected_model(
@@ -71,21 +74,14 @@ impl ModelSelector for AgentConnection {
cx: &mut AsyncApp,
) -> Task<Result<Arc<dyn LanguageModel>>> {
let agent = self.0.clone();
- let thread_result = agent
- .read_with(cx, |agent, _| agent.sessions.get(session_id).cloned())
- .ok()
- .flatten()
- .ok_or_else(|| anyhow::anyhow!("Session not found"));
-
- match thread_result {
- Ok(thread) => {
- let selected = thread
- .read_with(cx, |thread, _| thread.selected_model.clone())
- .unwrap_or_else(|e| panic!("Failed to read thread: {}", e));
- Task::ready(Ok(selected))
- }
- Err(e) => Task::ready(Err(e)),
- }
+ let session_id = session_id.clone();
+ cx.spawn(async move |cx| {
+ let thread = agent
+ .read_with(cx, |agent, _| agent.sessions.get(&session_id).cloned())?
+ .ok_or_else(|| anyhow::anyhow!("Session not found"))?;
+ let selected = thread.read_with(cx, |thread, _| thread.selected_model.clone())?;
+ Ok(selected)
+ })
}
}
@@ -1,16 +1,19 @@
use super::*;
use crate::templates::Templates;
+use acp_thread::AgentConnection as _;
+use agent_client_protocol as acp;
use client::{Client, UserStore};
use gpui::{AppContext, Entity, TestAppContext};
use language_model::{
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelRegistry, MessageContent, StopReason,
};
+use project::Project;
use reqwest_client::ReqwestClient;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use smol::stream::StreamExt;
-use std::{sync::Arc, time::Duration};
+use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
mod test_tools;
use test_tools::*;
@@ -187,6 +190,129 @@ async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
});
}
+#[gpui::test]
+async fn test_agent_connection(cx: &mut TestAppContext) {
+ cx.executor().allow_parking();
+ cx.update(settings::init);
+ let templates = Templates::new();
+
+ // Initialize language model system with test provider
+ cx.update(|cx| {
+ gpui_tokio::init(cx);
+ let http_client = ReqwestClient::user_agent("agent tests").unwrap();
+ cx.set_http_client(Arc::new(http_client));
+
+ client::init_settings(cx);
+ let client = Client::production(cx);
+ let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
+ language_model::init(client.clone(), cx);
+ language_models::init(user_store.clone(), client.clone(), cx);
+
+ // Initialize project settings
+ Project::init_settings(cx);
+
+ // Use test registry with fake provider
+ LanguageModelRegistry::test(cx);
+ });
+
+ // Create agent and connection
+ let agent = cx.new(|_| Agent::new(templates.clone()));
+ let connection = AgentConnection(agent.clone());
+
+ // Test model_selector returns Some
+ let selector_opt = connection.model_selector();
+ assert!(
+ selector_opt.is_some(),
+ "agent2 should always support ModelSelector"
+ );
+ let selector = selector_opt.unwrap();
+
+ // Test list_models
+ let listed_models = cx
+ .update(|cx| {
+ let mut async_cx = cx.to_async();
+ selector.list_models(&mut async_cx)
+ })
+ .await
+ .expect("list_models should succeed");
+ assert!(!listed_models.is_empty(), "should have at least one model");
+ assert_eq!(listed_models[0].id().0, "fake");
+
+ // Create a project for new_thread
+ let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
+ let project = Project::test(fake_fs, [Path::new("/test")], cx).await;
+
+ // Create a thread using new_thread
+ let cwd = Path::new("/test");
+ let connection_rc = Rc::new(connection.clone());
+ let acp_thread = cx
+ .update(|cx| {
+ let mut async_cx = cx.to_async();
+ connection_rc.new_thread(project, cwd, &mut async_cx)
+ })
+ .await
+ .expect("new_thread should succeed");
+
+ // Get the session_id from the AcpThread
+ let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
+
+ // Test selected_model returns the default
+ let selected = cx
+ .update(|cx| {
+ let mut async_cx = cx.to_async();
+ selector.selected_model(&session_id, &mut async_cx)
+ })
+ .await
+ .expect("selected_model should succeed");
+ assert_eq!(selected.id().0, "fake", "should return default model");
+
+ // The thread was created via prompt with the default model
+ // We can verify it through selected_model
+
+ // Test prompt uses the selected model
+ let prompt_request = acp::PromptRequest {
+ session_id: session_id.clone(),
+ prompt: vec![acp::ContentBlock::Text(acp::TextContent {
+ text: "Test prompt".into(),
+ annotations: None,
+ })],
+ };
+
+ cx.update(|cx| connection.prompt(prompt_request, cx))
+ .await
+ .expect("prompt should succeed");
+
+ // The prompt was sent successfully
+
+ // Test cancel
+ cx.update(|cx| connection.cancel(&session_id, cx));
+
+ // After cancel, selected_model should fail
+ let result = cx
+ .update(|cx| {
+ let mut async_cx = cx.to_async();
+ selector.selected_model(&session_id, &mut async_cx)
+ })
+ .await;
+ assert!(result.is_err(), "selected_model should fail after cancel");
+
+ // Test error case: invalid session
+ let invalid_session = acp::SessionId("invalid".into());
+ let result = cx
+ .update(|cx| {
+ let mut async_cx = cx.to_async();
+ selector.selected_model(&invalid_session, &mut async_cx)
+ })
+ .await;
+ assert!(result.is_err(), "should fail for invalid session");
+ if let Err(e) = result {
+ assert!(
+ e.to_string().contains("Session not found"),
+ "should have correct error message"
+ );
+ }
+}
+
/// Filters out the stop events for asserting against in tests
fn stop_events(
result_events: Vec<Result<AgentResponseEvent, LanguageModelCompletionError>>,