Add ModelSelector capability to AgentConnection

Nathan Sobo created

- Add ModelSelector trait to acp_thread crate with list_models, select_model, and selected_model methods
- Extend AgentConnection trait with optional model_selector() method returning Option<Rc<dyn ModelSelector>>
- Implement ModelSelector for agent2's AgentConnection using LanguageModelRegistry
- Make selected_model field mandatory on Thread struct
- Update Thread::new to require a default_model parameter
- Update agent2 to fetch default model from registry when creating threads
- Fix prompt method to use the thread's selected model directly
- All methods use &mut AsyncApp for async-friendly operations

Change summary

Cargo.lock                          |  1 
crates/acp_thread/Cargo.toml        |  1 
crates/acp_thread/src/connection.rs | 58 ++++++++++++++++++++
crates/agent2/src/agent.rs          | 88 ++++++++++++++++++++++++++++--
crates/agent2/src/tests/mod.rs      |  3 
crates/agent2/src/thread.rs         |  4 +
6 files changed, 144 insertions(+), 11 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -19,6 +19,7 @@ dependencies = [
  "indoc",
  "itertools 0.14.0",
  "language",
+ "language_model",
  "markdown",
  "project",
  "serde",

crates/acp_thread/Cargo.toml 🔗

@@ -26,6 +26,7 @@ futures.workspace = true
 gpui.workspace = true
 itertools.workspace = true
 language.workspace = true
+language_model.workspace = true
 markdown.workspace = true
 project.workspace = true
 serde.workspace = true

crates/acp_thread/src/connection.rs 🔗

@@ -1,13 +1,61 @@
-use std::{error::Error, fmt, path::Path, rc::Rc};
+use std::{error::Error, fmt, path::Path, rc::Rc, sync::Arc};
 
 use agent_client_protocol::{self as acp};
 use anyhow::Result;
 use gpui::{AsyncApp, Entity, Task};
+use language_model::LanguageModel;
 use project::Project;
 use ui::App;
 
 use crate::AcpThread;
 
+/// Trait for agents that support listing, selecting, and querying language models.
+///
+/// This is an optional capability; agents indicate support via [AgentConnection::model_selector].
+pub trait ModelSelector: 'static {
+    /// Lists all available language models for this agent.
+    ///
+    /// # Parameters
+    /// - `cx`: The GPUI app context for async operations and global access.
+    ///
+    /// # Returns
+    /// A task resolving to the list of models or an error (e.g., if no models are configured).
+    fn list_models(&self, cx: &mut AsyncApp) -> Task<Result<Vec<Arc<dyn LanguageModel>>>>;
+
+    /// Selects a model for a specific session (thread).
+    ///
+    /// This sets the default model for future interactions in the session.
+    /// If the session doesn't exist or the model is invalid, it returns an error.
+    ///
+    /// # Parameters
+    /// - `session_id`: The ID of the session (thread) to apply the model to.
+    /// - `model`: The model to select (should be one from [list_models]).
+    /// - `cx`: The GPUI app context.
+    ///
+    /// # Returns
+    /// A task resolving to `Ok(())` on success or an error.
+    fn select_model(
+        &self,
+        session_id: &acp::SessionId,
+        model: Arc<dyn LanguageModel>,
+        cx: &mut AsyncApp,
+    ) -> Task<Result<()>>;
+
+    /// Retrieves the currently selected model for a specific session (thread).
+    ///
+    /// # Parameters
+    /// - `session_id`: The ID of the session (thread) to query.
+    /// - `cx`: The GPUI app context.
+    ///
+    /// # Returns
+    /// A task resolving to the selected model (always set) or an error (e.g., session not found).
+    fn selected_model(
+        &self,
+        session_id: &acp::SessionId,
+        cx: &mut AsyncApp,
+    ) -> Task<Result<Arc<dyn LanguageModel>>>;
+}
+
 pub trait AgentConnection {
     fn new_thread(
         self: Rc<Self>,
@@ -23,6 +71,14 @@ pub trait AgentConnection {
     fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<Result<()>>;
 
     fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
+
+    /// Returns this agent as an [Rc<dyn ModelSelector>] if the model selection capability is supported.
+    ///
+    /// If the agent does not support model selection, returns [None].
+    /// This allows sharing the selector in UI components.
+    fn model_selector(&self) -> Option<Rc<dyn ModelSelector>> {
+        None // Default impl for agents that don't support it
+    }
 }
 
 #[derive(Debug)]

crates/agent2/src/agent.rs 🔗

@@ -1,6 +1,8 @@
+use acp_thread::ModelSelector;
 use agent_client_protocol as acp;
 use anyhow::Result;
 use gpui::{App, AppContext, AsyncApp, Entity, Task};
+use language_model::{LanguageModel, LanguageModelRegistry};
 use project::Project;
 use std::collections::HashMap;
 use std::path::Path;
@@ -26,8 +28,67 @@ impl Agent {
 }
 
 /// Wrapper struct that implements the AgentConnection trait
+#[derive(Clone)]
 pub struct AgentConnection(pub Entity<Agent>);
 
+impl ModelSelector for AgentConnection {
+    fn list_models(&self, cx: &mut AsyncApp) -> Task<Result<Vec<Arc<dyn LanguageModel>>>> {
+        let result = cx.update(|cx| {
+            let registry = LanguageModelRegistry::read_global(cx);
+            let models = registry.available_models(cx).collect::<Vec<_>>();
+            if models.is_empty() {
+                Err(anyhow::anyhow!("No models available"))
+            } else {
+                Ok(models)
+            }
+        });
+        Task::ready(result.unwrap_or_else(|e| Err(anyhow::anyhow!("Failed to update: {}", e))))
+    }
+
+    fn select_model(
+        &self,
+        session_id: &acp::SessionId,
+        model: Arc<dyn LanguageModel>,
+        cx: &mut AsyncApp,
+    ) -> Task<Result<()>> {
+        let agent = self.0.clone();
+        let result = agent.update(cx, |agent, cx| {
+            if let Some(thread) = agent.sessions.get(session_id) {
+                thread.update(cx, |thread, _| {
+                    thread.selected_model = model;
+                });
+                Ok(())
+            } else {
+                Err(anyhow::anyhow!("Session not found"))
+            }
+        });
+        Task::ready(result.unwrap_or_else(|e| Err(anyhow::anyhow!("Failed to update: {}", e))))
+    }
+
+    fn selected_model(
+        &self,
+        session_id: &acp::SessionId,
+        cx: &mut AsyncApp,
+    ) -> Task<Result<Arc<dyn LanguageModel>>> {
+        let agent = self.0.clone();
+        let thread_result = agent
+            .read_with(cx, |agent, _| agent.sessions.get(session_id).cloned())
+            .ok()
+            .flatten()
+            .ok_or_else(|| anyhow::anyhow!("Session not found"));
+
+        match thread_result {
+            Ok(thread) => {
+                let selected = thread
+                    .read_with(cx, |thread, _| thread.selected_model.clone())
+                    .unwrap_or_else(|e| panic!("Failed to read thread: {}", e));
+                Task::ready(Ok(selected))
+            }
+            Err(e) => Task::ready(Err(e)),
+        }
+    }
+}
+
 impl acp_thread::AgentConnection for AgentConnection {
     fn new_thread(
         self: Rc<Self>,
@@ -42,7 +103,13 @@ impl acp_thread::AgentConnection for AgentConnection {
             // Create Thread and store in Agent
             let (session_id, _thread) =
                 agent.update(cx, |agent, cx: &mut gpui::Context<Agent>| {
-                    let thread = cx.new(|_| Thread::new(agent.templates.clone()));
+                    // Fetch default model
+                    let default_model = LanguageModelRegistry::read_global(cx)
+                        .available_models(cx)
+                        .next()
+                        .unwrap_or_else(|| panic!("No default model available"));
+
+                    let thread = cx.new(|_| Thread::new(agent.templates.clone(), default_model));
                     let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into());
                     agent.sessions.insert(session_id.clone(), thread.clone());
                     (session_id, thread)
@@ -50,7 +117,9 @@ impl acp_thread::AgentConnection for AgentConnection {
 
             // Create AcpThread
             let acp_thread = cx.update(|cx| {
-                cx.new(|cx| acp_thread::AcpThread::new("agent2", self, project, session_id, cx))
+                cx.new(|cx| {
+                    acp_thread::AcpThread::new("agent2", self.clone(), project, session_id, cx)
+                })
             })?;
 
             Ok(acp_thread)
@@ -65,11 +134,15 @@ impl acp_thread::AgentConnection for AgentConnection {
         Task::ready(Ok(()))
     }
 
+    fn model_selector(&self) -> Option<Rc<dyn ModelSelector>> {
+        Some(Rc::new(self.clone()) as Rc<dyn ModelSelector>)
+    }
+
     fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<Result<()>> {
         let session_id = params.session_id.clone();
         let agent = self.0.clone();
 
-        cx.spawn(|cx| async move {
+        cx.spawn(async move |cx| {
             // Get thread
             let thread: Entity<Thread> = agent
                 .read_with(cx, |agent, _| agent.sessions.get(&session_id).cloned())?
@@ -78,13 +151,12 @@ impl acp_thread::AgentConnection for AgentConnection {
             // Convert prompt to message
             let message = convert_prompt_to_message(params.prompt);
 
-            // TODO: Get model from somewhere - for now use a placeholder
-            log::warn!("Model selection not implemented - need to get from UI context");
+            // Get model using the ModelSelector capability (always available for agent2)
+            // Get the selected model from the thread directly
+            let model = thread.read_with(cx, |thread, _| thread.selected_model.clone())?;
 
             // Send to thread
-            // thread.update(&mut cx, |thread, cx| {
-            //     thread.send(model, message, cx)
-            // })?;
+            thread.update(cx, |thread, cx| thread.send(model, message, cx))?;
 
             Ok(())
         })

crates/agent2/src/tests/mod.rs 🔗

@@ -209,7 +209,6 @@ async fn setup(cx: &mut TestAppContext) -> ThreadTest {
     cx.executor().allow_parking();
     cx.update(settings::init);
     let templates = Templates::new();
-    let thread = cx.new(|_| Thread::new(templates));
 
     let model = cx
         .update(|cx| {
@@ -239,6 +238,8 @@ async fn setup(cx: &mut TestAppContext) -> ThreadTest {
         })
         .await;
 
+    let thread = cx.new(|_| Thread::new(templates, model.clone()));
+
     ThreadTest { model, thread }
 }
 

crates/agent2/src/thread.rs 🔗

@@ -37,12 +37,13 @@ pub struct Thread {
     system_prompts: Vec<Arc<dyn Prompt>>,
     tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
     templates: Arc<Templates>,
+    pub selected_model: Arc<dyn LanguageModel>,
     // project: Entity<Project>,
     // action_log: Entity<ActionLog>,
 }
 
 impl Thread {
-    pub fn new(templates: Arc<Templates>) -> Self {
+    pub fn new(templates: Arc<Templates>, default_model: Arc<dyn LanguageModel>) -> Self {
         Self {
             messages: Vec::new(),
             completion_mode: CompletionMode::Normal,
@@ -50,6 +51,7 @@ impl Thread {
             running_turn: None,
             tools: BTreeMap::default(),
             templates,
+            selected_model: default_model,
         }
     }