Cargo.lock 🔗
@@ -19,6 +19,7 @@ dependencies = [
"indoc",
"itertools 0.14.0",
"language",
+ "language_model",
"markdown",
"project",
"serde",
Nathan Sobo created
- Add ModelSelector trait to acp_thread crate with list_models, select_model, and selected_model methods
- Extend AgentConnection trait with optional model_selector() method returning Option<Rc<dyn ModelSelector>>
- Implement ModelSelector for agent2's AgentConnection using LanguageModelRegistry
- Make selected_model field mandatory on Thread struct
- Update Thread::new to require a default_model parameter
- Update agent2 to fetch default model from registry when creating threads
- Fix prompt method to use the thread's selected model directly
- All methods use &mut AsyncApp for async-friendly operations
Cargo.lock | 1
crates/acp_thread/Cargo.toml | 1
crates/acp_thread/src/connection.rs | 58 ++++++++++++++++++++
crates/agent2/src/agent.rs | 88 ++++++++++++++++++++++++++++--
crates/agent2/src/tests/mod.rs | 3
crates/agent2/src/thread.rs | 4 +
6 files changed, 144 insertions(+), 11 deletions(-)
@@ -19,6 +19,7 @@ dependencies = [
"indoc",
"itertools 0.14.0",
"language",
+ "language_model",
"markdown",
"project",
"serde",
@@ -26,6 +26,7 @@ futures.workspace = true
gpui.workspace = true
itertools.workspace = true
language.workspace = true
+language_model.workspace = true
markdown.workspace = true
project.workspace = true
serde.workspace = true
@@ -1,13 +1,61 @@
-use std::{error::Error, fmt, path::Path, rc::Rc};
+use std::{error::Error, fmt, path::Path, rc::Rc, sync::Arc};
use agent_client_protocol::{self as acp};
use anyhow::Result;
use gpui::{AsyncApp, Entity, Task};
+use language_model::LanguageModel;
use project::Project;
use ui::App;
use crate::AcpThread;
+/// Trait for agents that support listing, selecting, and querying language models.
+///
+/// This is an optional capability; agents indicate support via [AgentConnection::model_selector].
+pub trait ModelSelector: 'static {
+ /// Lists all available language models for this agent.
+ ///
+ /// # Parameters
+ /// - `cx`: The GPUI app context for async operations and global access.
+ ///
+ /// # Returns
+ /// A task resolving to the list of models or an error (e.g., if no models are configured).
+ fn list_models(&self, cx: &mut AsyncApp) -> Task<Result<Vec<Arc<dyn LanguageModel>>>>;
+
+ /// Selects a model for a specific session (thread).
+ ///
+ /// This sets the default model for future interactions in the session.
+ /// If the session doesn't exist or the model is invalid, it returns an error.
+ ///
+ /// # Parameters
+ /// - `session_id`: The ID of the session (thread) to apply the model to.
+ /// - `model`: The model to select (should be one from [list_models]).
+ /// - `cx`: The GPUI app context.
+ ///
+ /// # Returns
+ /// A task resolving to `Ok(())` on success or an error.
+ fn select_model(
+ &self,
+ session_id: &acp::SessionId,
+ model: Arc<dyn LanguageModel>,
+ cx: &mut AsyncApp,
+ ) -> Task<Result<()>>;
+
+ /// Retrieves the currently selected model for a specific session (thread).
+ ///
+ /// # Parameters
+ /// - `session_id`: The ID of the session (thread) to query.
+ /// - `cx`: The GPUI app context.
+ ///
+ /// # Returns
+ /// A task resolving to the selected model (always set) or an error (e.g., session not found).
+ fn selected_model(
+ &self,
+ session_id: &acp::SessionId,
+ cx: &mut AsyncApp,
+ ) -> Task<Result<Arc<dyn LanguageModel>>>;
+}
+
pub trait AgentConnection {
fn new_thread(
self: Rc<Self>,
@@ -23,6 +71,14 @@ pub trait AgentConnection {
fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<Result<()>>;
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
+
+ /// Returns this agent as an [Rc<dyn ModelSelector>] if the model selection capability is supported.
+ ///
+ /// If the agent does not support model selection, returns [None].
+ /// This allows sharing the selector in UI components.
+ fn model_selector(&self) -> Option<Rc<dyn ModelSelector>> {
+ None // Default impl for agents that don't support it
+ }
}
#[derive(Debug)]
@@ -1,6 +1,8 @@
+use acp_thread::ModelSelector;
use agent_client_protocol as acp;
use anyhow::Result;
use gpui::{App, AppContext, AsyncApp, Entity, Task};
+use language_model::{LanguageModel, LanguageModelRegistry};
use project::Project;
use std::collections::HashMap;
use std::path::Path;
@@ -26,8 +28,67 @@ impl Agent {
}
/// Wrapper struct that implements the AgentConnection trait
+#[derive(Clone)]
pub struct AgentConnection(pub Entity<Agent>);
+impl ModelSelector for AgentConnection {
+ fn list_models(&self, cx: &mut AsyncApp) -> Task<Result<Vec<Arc<dyn LanguageModel>>>> {
+ let result = cx.update(|cx| {
+ let registry = LanguageModelRegistry::read_global(cx);
+ let models = registry.available_models(cx).collect::<Vec<_>>();
+ if models.is_empty() {
+ Err(anyhow::anyhow!("No models available"))
+ } else {
+ Ok(models)
+ }
+ });
+ Task::ready(result.unwrap_or_else(|e| Err(anyhow::anyhow!("Failed to update: {}", e))))
+ }
+
+ fn select_model(
+ &self,
+ session_id: &acp::SessionId,
+ model: Arc<dyn LanguageModel>,
+ cx: &mut AsyncApp,
+ ) -> Task<Result<()>> {
+ let agent = self.0.clone();
+ let result = agent.update(cx, |agent, cx| {
+ if let Some(thread) = agent.sessions.get(session_id) {
+ thread.update(cx, |thread, _| {
+ thread.selected_model = model;
+ });
+ Ok(())
+ } else {
+ Err(anyhow::anyhow!("Session not found"))
+ }
+ });
+ Task::ready(result.unwrap_or_else(|e| Err(anyhow::anyhow!("Failed to update: {}", e))))
+ }
+
+ fn selected_model(
+ &self,
+ session_id: &acp::SessionId,
+ cx: &mut AsyncApp,
+ ) -> Task<Result<Arc<dyn LanguageModel>>> {
+ let agent = self.0.clone();
+ let thread_result = agent
+ .read_with(cx, |agent, _| agent.sessions.get(session_id).cloned())
+ .ok()
+ .flatten()
+ .ok_or_else(|| anyhow::anyhow!("Session not found"));
+
+ match thread_result {
+ Ok(thread) => {
+ let selected = thread
+ .read_with(cx, |thread, _| thread.selected_model.clone())
+ .unwrap_or_else(|e| panic!("Failed to read thread: {}", e));
+ Task::ready(Ok(selected))
+ }
+ Err(e) => Task::ready(Err(e)),
+ }
+ }
+}
+
impl acp_thread::AgentConnection for AgentConnection {
fn new_thread(
self: Rc<Self>,
@@ -42,7 +103,13 @@ impl acp_thread::AgentConnection for AgentConnection {
// Create Thread and store in Agent
let (session_id, _thread) =
agent.update(cx, |agent, cx: &mut gpui::Context<Agent>| {
- let thread = cx.new(|_| Thread::new(agent.templates.clone()));
+ // Fetch default model
+ let default_model = LanguageModelRegistry::read_global(cx)
+ .available_models(cx)
+ .next()
+ .unwrap_or_else(|| panic!("No default model available"));
+
+ let thread = cx.new(|_| Thread::new(agent.templates.clone(), default_model));
let session_id = acp::SessionId(uuid::Uuid::new_v4().to_string().into());
agent.sessions.insert(session_id.clone(), thread.clone());
(session_id, thread)
@@ -50,7 +117,9 @@ impl acp_thread::AgentConnection for AgentConnection {
// Create AcpThread
let acp_thread = cx.update(|cx| {
- cx.new(|cx| acp_thread::AcpThread::new("agent2", self, project, session_id, cx))
+ cx.new(|cx| {
+ acp_thread::AcpThread::new("agent2", self.clone(), project, session_id, cx)
+ })
})?;
Ok(acp_thread)
@@ -65,11 +134,15 @@ impl acp_thread::AgentConnection for AgentConnection {
Task::ready(Ok(()))
}
+ fn model_selector(&self) -> Option<Rc<dyn ModelSelector>> {
+ Some(Rc::new(self.clone()) as Rc<dyn ModelSelector>)
+ }
+
fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<Result<()>> {
let session_id = params.session_id.clone();
let agent = self.0.clone();
- cx.spawn(|cx| async move {
+ cx.spawn(async move |cx| {
// Get thread
let thread: Entity<Thread> = agent
.read_with(cx, |agent, _| agent.sessions.get(&session_id).cloned())?
@@ -78,13 +151,12 @@ impl acp_thread::AgentConnection for AgentConnection {
// Convert prompt to message
let message = convert_prompt_to_message(params.prompt);
- // TODO: Get model from somewhere - for now use a placeholder
- log::warn!("Model selection not implemented - need to get from UI context");
+ // Get model using the ModelSelector capability (always available for agent2)
+ // Get the selected model from the thread directly
+ let model = thread.read_with(cx, |thread, _| thread.selected_model.clone())?;
// Send to thread
- // thread.update(&mut cx, |thread, cx| {
- // thread.send(model, message, cx)
- // })?;
+ thread.update(cx, |thread, cx| thread.send(model, message, cx))?;
Ok(())
})
@@ -209,7 +209,6 @@ async fn setup(cx: &mut TestAppContext) -> ThreadTest {
cx.executor().allow_parking();
cx.update(settings::init);
let templates = Templates::new();
- let thread = cx.new(|_| Thread::new(templates));
let model = cx
.update(|cx| {
@@ -239,6 +238,8 @@ async fn setup(cx: &mut TestAppContext) -> ThreadTest {
})
.await;
+ let thread = cx.new(|_| Thread::new(templates, model.clone()));
+
ThreadTest { model, thread }
}
@@ -37,12 +37,13 @@ pub struct Thread {
system_prompts: Vec<Arc<dyn Prompt>>,
tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
templates: Arc<Templates>,
+ pub selected_model: Arc<dyn LanguageModel>,
// project: Entity<Project>,
// action_log: Entity<ActionLog>,
}
impl Thread {
- pub fn new(templates: Arc<Templates>) -> Self {
+ pub fn new(templates: Arc<Templates>, default_model: Arc<dyn LanguageModel>) -> Self {
Self {
messages: Vec::new(),
completion_mode: CompletionMode::Normal,
@@ -50,6 +51,7 @@ impl Thread {
running_turn: None,
tools: BTreeMap::default(),
templates,
+ selected_model: default_model,
}
}