1use crate::{AcpThread, AcpThreadMetadata};
2use agent_client_protocol::{self as acp};
3use anyhow::Result;
4use collections::IndexMap;
5use futures::channel::mpsc::UnboundedReceiver;
6use gpui::{Entity, SharedString, Task};
7use project::Project;
8use serde::{Deserialize, Serialize};
9use std::{any::Any, error::Error, fmt, path::Path, rc::Rc, sync::Arc};
10use ui::{App, IconName};
11use uuid::Uuid;
12
13#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
14pub struct UserMessageId(Arc<str>);
15
16impl UserMessageId {
17 pub fn new() -> Self {
18 Self(Uuid::new_v4().to_string().into())
19 }
20}
21
22pub trait AgentConnection {
23 fn new_thread(
24 self: Rc<Self>,
25 project: Entity<Project>,
26 cwd: &Path,
27 cx: &mut App,
28 ) -> Task<Result<Entity<AcpThread>>>;
29
30 fn list_threads(&self, _cx: &mut App) -> Option<UnboundedReceiver<Vec<AcpThreadMetadata>>> {
31 return None;
32 }
33
34 fn load_thread(
35 self: Rc<Self>,
36 _project: Entity<Project>,
37 _cwd: &Path,
38 _session_id: acp::SessionId,
39 _cx: &mut App,
40 ) -> Task<Result<Entity<AcpThread>>> {
41 Task::ready(Err(anyhow::anyhow!("load thread not implemented")))
42 }
43
44 fn auth_methods(&self) -> &[acp::AuthMethod];
45
46 fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;
47
48 fn prompt(
49 &self,
50 user_message_id: Option<UserMessageId>,
51 params: acp::PromptRequest,
52 cx: &mut App,
53 ) -> Task<Result<acp::PromptResponse>>;
54
55 fn resume(
56 &self,
57 _session_id: &acp::SessionId,
58 _cx: &mut App,
59 ) -> Option<Rc<dyn AgentSessionResume>> {
60 None
61 }
62
63 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
64
65 fn session_editor(
66 &self,
67 _session_id: &acp::SessionId,
68 _cx: &mut App,
69 ) -> Option<Rc<dyn AgentSessionEditor>> {
70 None
71 }
72
73 /// Returns this agent as an [Rc<dyn ModelSelector>] if the model selection capability is supported.
74 ///
75 /// If the agent does not support model selection, returns [None].
76 /// This allows sharing the selector in UI components.
77 fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
78 None
79 }
80
81 fn into_any(self: Rc<Self>) -> Rc<dyn Any>;
82}
83
84impl dyn AgentConnection {
85 pub fn downcast<T: 'static + AgentConnection + Sized>(self: Rc<Self>) -> Option<Rc<T>> {
86 self.into_any().downcast().ok()
87 }
88}
89
90pub trait AgentSessionEditor {
91 fn truncate(&self, message_id: UserMessageId, cx: &mut App) -> Task<Result<()>>;
92}
93
94pub trait AgentSessionResume {
95 fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>>;
96}
97
/// Marker error type named for the "authentication required" condition.
///
/// It carries no data; both its `Debug` and `Display` output are the literal
/// text `AuthRequired`.
#[derive(Debug)]
pub struct AuthRequired;

impl Error for AuthRequired {}

impl fmt::Display for AuthRequired {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // A fixed string needs no formatting machinery.
        f.write_str("AuthRequired")
    }
}
107
108/// Trait for agents that support listing, selecting, and querying language models.
109///
110/// This is an optional capability; agents indicate support via [AgentConnection::model_selector].
111pub trait AgentModelSelector: 'static {
112 /// Lists all available language models for this agent.
113 ///
114 /// # Parameters
115 /// - `cx`: The GPUI app context for async operations and global access.
116 ///
117 /// # Returns
118 /// A task resolving to the list of models or an error (e.g., if no models are configured).
119 fn list_models(&self, cx: &mut App) -> Task<Result<AgentModelList>>;
120
121 /// Selects a model for a specific session (thread).
122 ///
123 /// This sets the default model for future interactions in the session.
124 /// If the session doesn't exist or the model is invalid, it returns an error.
125 ///
126 /// # Parameters
127 /// - `session_id`: The ID of the session (thread) to apply the model to.
128 /// - `model`: The model to select (should be one from [list_models]).
129 /// - `cx`: The GPUI app context.
130 ///
131 /// # Returns
132 /// A task resolving to `Ok(())` on success or an error.
133 fn select_model(
134 &self,
135 session_id: acp::SessionId,
136 model_id: AgentModelId,
137 cx: &mut App,
138 ) -> Task<Result<()>>;
139
140 /// Retrieves the currently selected model for a specific session (thread).
141 ///
142 /// # Parameters
143 /// - `session_id`: The ID of the session (thread) to query.
144 /// - `cx`: The GPUI app context.
145 ///
146 /// # Returns
147 /// A task resolving to the selected model (always set) or an error (e.g., session not found).
148 fn selected_model(
149 &self,
150 session_id: &acp::SessionId,
151 cx: &mut App,
152 ) -> Task<Result<AgentModelInfo>>;
153
154 /// Whenever the model list is updated the receiver will be notified.
155 fn watch(&self, cx: &mut App) -> watch::Receiver<()>;
156}
157
158#[derive(Debug, Clone, PartialEq, Eq, Hash)]
159pub struct AgentModelId(pub SharedString);
160
161impl std::ops::Deref for AgentModelId {
162 type Target = SharedString;
163
164 fn deref(&self) -> &Self::Target {
165 &self.0
166 }
167}
168
169impl fmt::Display for AgentModelId {
170 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
171 self.0.fmt(f)
172 }
173}
174
175#[derive(Debug, Clone, PartialEq, Eq)]
176pub struct AgentModelInfo {
177 pub id: AgentModelId,
178 pub name: SharedString,
179 pub icon: Option<IconName>,
180}
181
182#[derive(Debug, Clone, PartialEq, Eq, Hash)]
183pub struct AgentModelGroupName(pub SharedString);
184
185#[derive(Debug, Clone)]
186pub enum AgentModelList {
187 Flat(Vec<AgentModelInfo>),
188 Grouped(IndexMap<AgentModelGroupName, Vec<AgentModelInfo>>),
189}
190
191impl AgentModelList {
192 pub fn is_empty(&self) -> bool {
193 match self {
194 AgentModelList::Flat(models) => models.is_empty(),
195 AgentModelList::Grouped(groups) => groups.is_empty(),
196 }
197 }
198}
199
200#[cfg(feature = "test-support")]
201mod test_support {
202 use std::sync::Arc;
203
204 use collections::HashMap;
205 use futures::future::try_join_all;
206 use gpui::{AppContext as _, WeakEntity};
207 use parking_lot::Mutex;
208
209 use super::*;
210
211 #[derive(Clone, Default)]
212 pub struct StubAgentConnection {
213 sessions: Arc<Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
214 permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
215 next_prompt_updates: Arc<Mutex<Vec<acp::SessionUpdate>>>,
216 }
217
218 impl StubAgentConnection {
219 pub fn new() -> Self {
220 Self {
221 next_prompt_updates: Default::default(),
222 permission_requests: HashMap::default(),
223 sessions: Arc::default(),
224 }
225 }
226
227 pub fn set_next_prompt_updates(&self, updates: Vec<acp::SessionUpdate>) {
228 *self.next_prompt_updates.lock() = updates;
229 }
230
231 pub fn with_permission_requests(
232 mut self,
233 permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
234 ) -> Self {
235 self.permission_requests = permission_requests;
236 self
237 }
238
239 pub fn send_update(
240 &self,
241 session_id: acp::SessionId,
242 update: acp::SessionUpdate,
243 cx: &mut App,
244 ) {
245 self.sessions
246 .lock()
247 .get(&session_id)
248 .unwrap()
249 .update(cx, |thread, cx| {
250 thread.handle_session_update(update.clone(), cx).unwrap();
251 })
252 .unwrap();
253 }
254 }
255
256 impl AgentConnection for StubAgentConnection {
257 fn auth_methods(&self) -> &[acp::AuthMethod] {
258 &[]
259 }
260
261 fn new_thread(
262 self: Rc<Self>,
263 project: Entity<Project>,
264 _cwd: &Path,
265 cx: &mut gpui::App,
266 ) -> Task<gpui::Result<Entity<AcpThread>>> {
267 let session_id = acp::SessionId(self.sessions.lock().len().to_string().into());
268 let thread =
269 cx.new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx));
270 self.sessions.lock().insert(session_id, thread.downgrade());
271 Task::ready(Ok(thread))
272 }
273
274 fn authenticate(
275 &self,
276 _method_id: acp::AuthMethodId,
277 _cx: &mut App,
278 ) -> Task<gpui::Result<()>> {
279 unimplemented!()
280 }
281
282 fn prompt(
283 &self,
284 _id: Option<UserMessageId>,
285 params: acp::PromptRequest,
286 cx: &mut App,
287 ) -> Task<gpui::Result<acp::PromptResponse>> {
288 let sessions = self.sessions.lock();
289 let thread = sessions.get(¶ms.session_id).unwrap();
290 let mut tasks = vec![];
291 for update in self.next_prompt_updates.lock().drain(..) {
292 let thread = thread.clone();
293 let update = update.clone();
294 let permission_request = if let acp::SessionUpdate::ToolCall(tool_call) = &update
295 && let Some(options) = self.permission_requests.get(&tool_call.id)
296 {
297 Some((tool_call.clone(), options.clone()))
298 } else {
299 None
300 };
301 let task = cx.spawn(async move |cx| {
302 if let Some((tool_call, options)) = permission_request {
303 let permission = thread.update(cx, |thread, cx| {
304 thread.request_tool_call_authorization(
305 tool_call.clone().into(),
306 options.clone(),
307 cx,
308 )
309 })?;
310 permission?.await?;
311 }
312 thread.update(cx, |thread, cx| {
313 thread.handle_session_update(update.clone(), cx).unwrap();
314 })?;
315 anyhow::Ok(())
316 });
317 tasks.push(task);
318 }
319 cx.spawn(async move |_| {
320 try_join_all(tasks).await?;
321 Ok(acp::PromptResponse {
322 stop_reason: acp::StopReason::EndTurn,
323 })
324 })
325 }
326
327 fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) {
328 unimplemented!()
329 }
330
331 fn session_editor(
332 &self,
333 _session_id: &agent_client_protocol::SessionId,
334 _cx: &mut App,
335 ) -> Option<Rc<dyn AgentSessionEditor>> {
336 Some(Rc::new(StubAgentSessionEditor))
337 }
338
339 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
340 self
341 }
342 }
343
344 struct StubAgentSessionEditor;
345
346 impl AgentSessionEditor for StubAgentSessionEditor {
347 fn truncate(&self, _: UserMessageId, _: &mut App) -> Task<Result<()>> {
348 Task::ready(Ok(()))
349 }
350 }
351}
352
353#[cfg(feature = "test-support")]
354pub use test_support::*;