1use crate::AcpThread;
2use agent_client_protocol::{self as acp};
3use anyhow::Result;
4use collections::IndexMap;
5use gpui::{Entity, SharedString, Task};
6use project::Project;
7use std::{any::Any, error::Error, fmt, path::Path, rc::Rc, sync::Arc};
8use ui::{App, IconName};
9use uuid::Uuid;
10
11#[derive(Clone, Debug, Eq, PartialEq)]
12pub struct UserMessageId(Arc<str>);
13
14impl UserMessageId {
15 pub fn new() -> Self {
16 Self(Uuid::new_v4().to_string().into())
17 }
18}
19
20pub trait AgentConnection {
21 fn new_thread(
22 self: Rc<Self>,
23 project: Entity<Project>,
24 cwd: &Path,
25 cx: &mut App,
26 ) -> Task<Result<Entity<AcpThread>>>;
27
28 fn auth_methods(&self) -> &[acp::AuthMethod];
29
30 fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;
31
32 fn prompt(
33 &self,
34 user_message_id: Option<UserMessageId>,
35 params: acp::PromptRequest,
36 cx: &mut App,
37 ) -> Task<Result<acp::PromptResponse>>;
38
39 fn resume(
40 &self,
41 _session_id: &acp::SessionId,
42 _cx: &mut App,
43 ) -> Option<Rc<dyn AgentSessionResume>> {
44 None
45 }
46
47 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
48
49 fn session_editor(
50 &self,
51 _session_id: &acp::SessionId,
52 _cx: &mut App,
53 ) -> Option<Rc<dyn AgentSessionEditor>> {
54 None
55 }
56
57 /// Returns this agent as an [Rc<dyn ModelSelector>] if the model selection capability is supported.
58 ///
59 /// If the agent does not support model selection, returns [None].
60 /// This allows sharing the selector in UI components.
61 fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
62 None
63 }
64
65 fn into_any(self: Rc<Self>) -> Rc<dyn Any>;
66}
67
68impl dyn AgentConnection {
69 pub fn downcast<T: 'static + AgentConnection + Sized>(self: Rc<Self>) -> Option<Rc<T>> {
70 self.into_any().downcast().ok()
71 }
72}
73
74pub trait AgentSessionEditor {
75 fn truncate(&self, message_id: UserMessageId, cx: &mut App) -> Task<Result<()>>;
76}
77
78pub trait AgentSessionResume {
79 fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>>;
80}
81
/// Error value signalling that authentication is required.
///
/// Both its `Debug` and `Display` output are the literal text `AuthRequired`.
#[derive(Debug)]
pub struct AuthRequired;

impl fmt::Display for AuthRequired {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("AuthRequired")
    }
}

impl Error for AuthRequired {}
91
92/// Trait for agents that support listing, selecting, and querying language models.
93///
94/// This is an optional capability; agents indicate support via [AgentConnection::model_selector].
95pub trait AgentModelSelector: 'static {
96 /// Lists all available language models for this agent.
97 ///
98 /// # Parameters
99 /// - `cx`: The GPUI app context for async operations and global access.
100 ///
101 /// # Returns
102 /// A task resolving to the list of models or an error (e.g., if no models are configured).
103 fn list_models(&self, cx: &mut App) -> Task<Result<AgentModelList>>;
104
105 /// Selects a model for a specific session (thread).
106 ///
107 /// This sets the default model for future interactions in the session.
108 /// If the session doesn't exist or the model is invalid, it returns an error.
109 ///
110 /// # Parameters
111 /// - `session_id`: The ID of the session (thread) to apply the model to.
112 /// - `model`: The model to select (should be one from [list_models]).
113 /// - `cx`: The GPUI app context.
114 ///
115 /// # Returns
116 /// A task resolving to `Ok(())` on success or an error.
117 fn select_model(
118 &self,
119 session_id: acp::SessionId,
120 model_id: AgentModelId,
121 cx: &mut App,
122 ) -> Task<Result<()>>;
123
124 /// Retrieves the currently selected model for a specific session (thread).
125 ///
126 /// # Parameters
127 /// - `session_id`: The ID of the session (thread) to query.
128 /// - `cx`: The GPUI app context.
129 ///
130 /// # Returns
131 /// A task resolving to the selected model (always set) or an error (e.g., session not found).
132 fn selected_model(
133 &self,
134 session_id: &acp::SessionId,
135 cx: &mut App,
136 ) -> Task<Result<AgentModelInfo>>;
137
138 /// Whenever the model list is updated the receiver will be notified.
139 fn watch(&self, cx: &mut App) -> watch::Receiver<()>;
140}
141
142#[derive(Debug, Clone, PartialEq, Eq, Hash)]
143pub struct AgentModelId(pub SharedString);
144
145impl std::ops::Deref for AgentModelId {
146 type Target = SharedString;
147
148 fn deref(&self) -> &Self::Target {
149 &self.0
150 }
151}
152
153impl fmt::Display for AgentModelId {
154 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
155 self.0.fmt(f)
156 }
157}
158
159#[derive(Debug, Clone, PartialEq, Eq)]
160pub struct AgentModelInfo {
161 pub id: AgentModelId,
162 pub name: SharedString,
163 pub icon: Option<IconName>,
164}
165
166#[derive(Debug, Clone, PartialEq, Eq, Hash)]
167pub struct AgentModelGroupName(pub SharedString);
168
169#[derive(Debug, Clone)]
170pub enum AgentModelList {
171 Flat(Vec<AgentModelInfo>),
172 Grouped(IndexMap<AgentModelGroupName, Vec<AgentModelInfo>>),
173}
174
175impl AgentModelList {
176 pub fn is_empty(&self) -> bool {
177 match self {
178 AgentModelList::Flat(models) => models.is_empty(),
179 AgentModelList::Grouped(groups) => groups.is_empty(),
180 }
181 }
182}
183
#[cfg(feature = "test-support")]
mod test_support {
    use std::sync::Arc;

    use collections::HashMap;
    use futures::{channel::oneshot, future::try_join_all};
    use gpui::{AppContext as _, WeakEntity};
    use parking_lot::Mutex;

    use super::*;

    /// A scriptable in-memory [`AgentConnection`] for tests.
    ///
    /// If updates were queued with [`StubAgentConnection::set_next_prompt_updates`], the next
    /// `prompt` applies them and then ends the turn. Otherwise the turn stays open until
    /// [`StubAgentConnection::end_turn`] is called.
    #[derive(Clone, Default)]
    pub struct StubAgentConnection {
        sessions: Arc<Mutex<HashMap<acp::SessionId, Session>>>,
        /// Tool calls (by id) that must be authorized before their update is applied, mapped to
        /// the options passed to `AcpThread::request_tool_call_authorization`.
        permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
        next_prompt_updates: Arc<Mutex<Vec<acp::SessionUpdate>>>,
    }

    /// Per-session state tracked by [`StubAgentConnection`].
    struct Session {
        thread: WeakEntity<AcpThread>,
        /// Set while a prompt is waiting for `end_turn`; firing it completes that prompt.
        response_tx: Option<oneshot::Sender<()>>,
    }

    impl StubAgentConnection {
        /// Creates a connection with no sessions, no queued updates and no permission requests.
        pub fn new() -> Self {
            Self::default()
        }

        /// Queues `updates` to be applied by the next call to `prompt`.
        pub fn set_next_prompt_updates(&self, updates: Vec<acp::SessionUpdate>) {
            *self.next_prompt_updates.lock() = updates;
        }

        /// Requires authorization for the given tool calls before their updates are applied.
        pub fn with_permission_requests(
            mut self,
            permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
        ) -> Self {
            self.permission_requests = permission_requests;
            self
        }

        /// Applies `update` to the thread of `session_id` immediately.
        ///
        /// # Panics
        ///
        /// Panics if updates are queued via `set_next_prompt_updates` (the two mechanisms are
        /// mutually exclusive), if the session is unknown or its thread was released, or if
        /// the thread fails to handle the update.
        pub fn send_update(
            &self,
            session_id: acp::SessionId,
            update: acp::SessionUpdate,
            cx: &mut App,
        ) {
            assert!(
                self.next_prompt_updates.lock().is_empty(),
                "Use either send_update or set_next_prompt_updates"
            );

            // Copy the weak handle out so the sessions lock is not held while the thread runs.
            let thread = self.sessions.lock().get(&session_id).unwrap().thread.clone();
            thread
                .update(cx, |thread, cx| {
                    thread.handle_session_update(update, cx).unwrap();
                })
                .unwrap();
        }

        /// Completes the turn left open by the last `prompt` on `session_id`.
        ///
        /// # Panics
        ///
        /// Panics if the session is unknown, has no pending turn, or the prompt task waiting
        /// on the turn was dropped.
        pub fn end_turn(&self, session_id: acp::SessionId) {
            let response_tx = self
                .sessions
                .lock()
                .get_mut(&session_id)
                .unwrap()
                .response_tx
                .take()
                .expect("No pending turn");
            response_tx.send(()).unwrap();
        }
    }

    impl AgentConnection for StubAgentConnection {
        fn auth_methods(&self) -> &[acp::AuthMethod] {
            &[]
        }

        fn new_thread(
            self: Rc<Self>,
            project: Entity<Project>,
            _cwd: &Path,
            cx: &mut gpui::App,
        ) -> Task<gpui::Result<Entity<AcpThread>>> {
            // Session ids are sequential: "0", "1", ...
            let session_id = acp::SessionId(self.sessions.lock().len().to_string().into());
            let thread =
                cx.new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx));
            self.sessions.lock().insert(
                session_id,
                Session {
                    thread: thread.downgrade(),
                    response_tx: None,
                },
            );
            Task::ready(Ok(thread))
        }

        fn authenticate(
            &self,
            _method_id: acp::AuthMethodId,
            _cx: &mut App,
        ) -> Task<gpui::Result<()>> {
            unimplemented!()
        }

        fn prompt(
            &self,
            _id: Option<UserMessageId>,
            params: acp::PromptRequest,
            cx: &mut App,
        ) -> Task<gpui::Result<acp::PromptResponse>> {
            let mut sessions = self.sessions.lock();
            let session = sessions.get_mut(&params.session_id).unwrap();
            let updates = std::mem::take(&mut *self.next_prompt_updates.lock());

            if updates.is_empty() {
                // Nothing scripted: keep the turn open until `end_turn` fires `response_tx`.
                let (tx, rx) = oneshot::channel();
                session.response_tx.replace(tx);
                return cx.spawn(async move |_| {
                    rx.await?;
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                });
            }

            // The spawned tasks only need the weak thread handle; release the lock now.
            let thread = session.thread.clone();
            drop(sessions);

            let tasks = updates
                .into_iter()
                .map(|update| {
                    let thread = thread.clone();
                    // Tool calls listed in `permission_requests` must be authorized first.
                    let permission_request = match &update {
                        acp::SessionUpdate::ToolCall(tool_call) => self
                            .permission_requests
                            .get(&tool_call.id)
                            .map(|options| (tool_call.clone(), options.clone())),
                        _ => None,
                    };
                    cx.spawn(async move |cx| {
                        if let Some((tool_call, options)) = permission_request {
                            let permission = thread.update(cx, |thread, cx| {
                                thread.request_tool_call_authorization(
                                    tool_call.into(),
                                    options,
                                    cx,
                                )
                            })?;
                            permission?.await?;
                        }
                        thread.update(cx, |thread, cx| {
                            thread.handle_session_update(update, cx).unwrap();
                        })?;
                        anyhow::Ok(())
                    })
                })
                .collect::<Vec<_>>();

            cx.spawn(async move |_| {
                try_join_all(tasks).await?;
                Ok(acp::PromptResponse {
                    stop_reason: acp::StopReason::EndTurn,
                })
            })
        }

        fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) {
            unimplemented!()
        }

        fn session_editor(
            &self,
            _session_id: &acp::SessionId,
            _cx: &mut App,
        ) -> Option<Rc<dyn AgentSessionEditor>> {
            Some(Rc::new(StubAgentSessionEditor))
        }

        fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
            self
        }
    }

    /// Session editor whose `truncate` always succeeds without doing anything.
    struct StubAgentSessionEditor;

    impl AgentSessionEditor for StubAgentSessionEditor {
        fn truncate(&self, _: UserMessageId, _: &mut App) -> Task<Result<()>> {
            Task::ready(Ok(()))
        }
    }
}
381
382#[cfg(feature = "test-support")]
383pub use test_support::*;