1use crate::{AcpThread, AcpThreadMetadata};
2use agent_client_protocol::{self as acp};
3use anyhow::Result;
4use collections::IndexMap;
5use gpui::{Entity, SharedString, Task};
6use project::Project;
7use serde::{Deserialize, Serialize};
8use std::{any::Any, error::Error, fmt, path::Path, rc::Rc, sync::Arc};
9use ui::{App, IconName};
10use uuid::Uuid;
11
12#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
13pub struct UserMessageId(Arc<str>);
14
15impl UserMessageId {
16 pub fn new() -> Self {
17 Self(Uuid::new_v4().to_string().into())
18 }
19}
20
21pub trait AgentConnection {
22 fn new_thread(
23 self: Rc<Self>,
24 project: Entity<Project>,
25 cwd: &Path,
26 cx: &mut App,
27 ) -> Task<Result<Entity<AcpThread>>>;
28
29 // todo!(expose a history trait, and include list_threads and load_thread)
30 // todo!(write a test)
31 fn list_threads(
32 &self,
33 _cx: &mut App,
34 ) -> Option<watch::Receiver<Option<Vec<AcpThreadMetadata>>>> {
35 return None;
36 }
37
38 fn load_thread(
39 self: Rc<Self>,
40 _project: Entity<Project>,
41 _cwd: &Path,
42 _session_id: acp::SessionId,
43 _cx: &mut App,
44 ) -> Task<Result<Entity<AcpThread>>> {
45 Task::ready(Err(anyhow::anyhow!("load thread not implemented")))
46 }
47
48 fn auth_methods(&self) -> &[acp::AuthMethod];
49
50 fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;
51
52 fn prompt(
53 &self,
54 user_message_id: Option<UserMessageId>,
55 params: acp::PromptRequest,
56 cx: &mut App,
57 ) -> Task<Result<acp::PromptResponse>>;
58
59 fn resume(
60 &self,
61 _session_id: &acp::SessionId,
62 _cx: &mut App,
63 ) -> Option<Rc<dyn AgentSessionResume>> {
64 None
65 }
66
67 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
68
69 fn session_editor(
70 &self,
71 _session_id: &acp::SessionId,
72 _cx: &mut App,
73 ) -> Option<Rc<dyn AgentSessionEditor>> {
74 None
75 }
76
77 /// Returns this agent as an [Rc<dyn ModelSelector>] if the model selection capability is supported.
78 ///
79 /// If the agent does not support model selection, returns [None].
80 /// This allows sharing the selector in UI components.
81 fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
82 None
83 }
84
85 fn into_any(self: Rc<Self>) -> Rc<dyn Any>;
86}
87
88impl dyn AgentConnection {
89 pub fn downcast<T: 'static + AgentConnection + Sized>(self: Rc<Self>) -> Option<Rc<T>> {
90 self.into_any().downcast().ok()
91 }
92}
93
/// Optional capability for modifying an existing agent session.
///
/// Obtained through `AgentConnection::session_editor`.
pub trait AgentSessionEditor {
    /// Truncates the session's history at the user message identified by `message_id`.
    ///
    /// NOTE(review): whether `message_id` itself is kept is implementation-defined here —
    /// confirm against the concrete editors.
    fn truncate(&self, message_id: UserMessageId, cx: &mut App) -> Task<Result<()>>;
}
97
/// Optional capability for resuming an agent session.
///
/// Obtained through `AgentConnection::resume`.
pub trait AgentSessionResume {
    /// Resumes the session and resolves with the resulting [`acp::PromptResponse`].
    fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>>;
}
101
/// Marker error whose `Display` output is `AuthRequired`.
///
/// Presumably returned when the agent needs the user to authenticate before the
/// request can proceed — confirm against the call sites that construct it.
#[derive(Debug)]
pub struct AuthRequired;

impl fmt::Display for AuthRequired {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("AuthRequired")
    }
}

impl Error for AuthRequired {}
111
112/// Trait for agents that support listing, selecting, and querying language models.
113///
114/// This is an optional capability; agents indicate support via [AgentConnection::model_selector].
115pub trait AgentModelSelector: 'static {
116 /// Lists all available language models for this agent.
117 ///
118 /// # Parameters
119 /// - `cx`: The GPUI app context for async operations and global access.
120 ///
121 /// # Returns
122 /// A task resolving to the list of models or an error (e.g., if no models are configured).
123 fn list_models(&self, cx: &mut App) -> Task<Result<AgentModelList>>;
124
125 /// Selects a model for a specific session (thread).
126 ///
127 /// This sets the default model for future interactions in the session.
128 /// If the session doesn't exist or the model is invalid, it returns an error.
129 ///
130 /// # Parameters
131 /// - `session_id`: The ID of the session (thread) to apply the model to.
132 /// - `model`: The model to select (should be one from [list_models]).
133 /// - `cx`: The GPUI app context.
134 ///
135 /// # Returns
136 /// A task resolving to `Ok(())` on success or an error.
137 fn select_model(
138 &self,
139 session_id: acp::SessionId,
140 model_id: AgentModelId,
141 cx: &mut App,
142 ) -> Task<Result<()>>;
143
144 /// Retrieves the currently selected model for a specific session (thread).
145 ///
146 /// # Parameters
147 /// - `session_id`: The ID of the session (thread) to query.
148 /// - `cx`: The GPUI app context.
149 ///
150 /// # Returns
151 /// A task resolving to the selected model (always set) or an error (e.g., session not found).
152 fn selected_model(
153 &self,
154 session_id: &acp::SessionId,
155 cx: &mut App,
156 ) -> Task<Result<AgentModelInfo>>;
157
158 /// Whenever the model list is updated the receiver will be notified.
159 fn watch(&self, cx: &mut App) -> watch::Receiver<()>;
160}
161
162#[derive(Debug, Clone, PartialEq, Eq, Hash)]
163pub struct AgentModelId(pub SharedString);
164
165impl std::ops::Deref for AgentModelId {
166 type Target = SharedString;
167
168 fn deref(&self) -> &Self::Target {
169 &self.0
170 }
171}
172
173impl fmt::Display for AgentModelId {
174 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
175 self.0.fmt(f)
176 }
177}
178
/// Description of a single language model offered by an agent.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AgentModelInfo {
    /// Identifier of the model, as accepted by `AgentModelSelector::select_model`.
    pub id: AgentModelId,
    /// Human-readable name (presumably what the UI displays — confirm).
    pub name: SharedString,
    /// Optional icon associated with the model.
    pub icon: Option<IconName>,
}
185
/// Name of a group of models; used as the key of `AgentModelList::Grouped`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct AgentModelGroupName(pub SharedString);
188
189#[derive(Debug, Clone)]
190pub enum AgentModelList {
191 Flat(Vec<AgentModelInfo>),
192 Grouped(IndexMap<AgentModelGroupName, Vec<AgentModelInfo>>),
193}
194
195impl AgentModelList {
196 pub fn is_empty(&self) -> bool {
197 match self {
198 AgentModelList::Flat(models) => models.is_empty(),
199 AgentModelList::Grouped(groups) => groups.is_empty(),
200 }
201 }
202}
203
204#[cfg(feature = "test-support")]
205mod test_support {
206 use std::sync::Arc;
207
208 use collections::HashMap;
209 use futures::{channel::oneshot, future::try_join_all};
210 use gpui::{AppContext as _, WeakEntity};
211 use parking_lot::Mutex;
212
213 use super::*;
214
215 #[derive(Clone, Default)]
216 pub struct StubAgentConnection {
217 sessions: Arc<Mutex<HashMap<acp::SessionId, Session>>>,
218 permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
219 next_prompt_updates: Arc<Mutex<Vec<acp::SessionUpdate>>>,
220 }
221
222 struct Session {
223 thread: WeakEntity<AcpThread>,
224 response_tx: Option<oneshot::Sender<acp::StopReason>>,
225 }
226
227 impl StubAgentConnection {
228 pub fn new() -> Self {
229 Self {
230 next_prompt_updates: Default::default(),
231 permission_requests: HashMap::default(),
232 sessions: Arc::default(),
233 }
234 }
235
236 pub fn set_next_prompt_updates(&self, updates: Vec<acp::SessionUpdate>) {
237 *self.next_prompt_updates.lock() = updates;
238 }
239
240 pub fn with_permission_requests(
241 mut self,
242 permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
243 ) -> Self {
244 self.permission_requests = permission_requests;
245 self
246 }
247
248 pub fn send_update(
249 &self,
250 session_id: acp::SessionId,
251 update: acp::SessionUpdate,
252 cx: &mut App,
253 ) {
254 assert!(
255 self.next_prompt_updates.lock().is_empty(),
256 "Use either send_update or set_next_prompt_updates"
257 );
258
259 self.sessions
260 .lock()
261 .get(&session_id)
262 .unwrap()
263 .thread
264 .update(cx, |thread, cx| {
265 thread.handle_session_update(update, cx).unwrap();
266 })
267 .unwrap();
268 }
269
270 pub fn end_turn(&self, session_id: acp::SessionId, stop_reason: acp::StopReason) {
271 self.sessions
272 .lock()
273 .get_mut(&session_id)
274 .unwrap()
275 .response_tx
276 .take()
277 .expect("No pending turn")
278 .send(stop_reason)
279 .unwrap();
280 }
281 }
282
283 impl AgentConnection for StubAgentConnection {
284 fn auth_methods(&self) -> &[acp::AuthMethod] {
285 &[]
286 }
287
288 fn new_thread(
289 self: Rc<Self>,
290 project: Entity<Project>,
291 _cwd: &Path,
292 cx: &mut gpui::App,
293 ) -> Task<gpui::Result<Entity<AcpThread>>> {
294 let session_id = acp::SessionId(self.sessions.lock().len().to_string().into());
295 let thread =
296 cx.new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx));
297 self.sessions.lock().insert(
298 session_id,
299 Session {
300 thread: thread.downgrade(),
301 response_tx: None,
302 },
303 );
304 Task::ready(Ok(thread))
305 }
306
307 fn authenticate(
308 &self,
309 _method_id: acp::AuthMethodId,
310 _cx: &mut App,
311 ) -> Task<gpui::Result<()>> {
312 unimplemented!()
313 }
314
315 fn prompt(
316 &self,
317 _id: Option<UserMessageId>,
318 params: acp::PromptRequest,
319 cx: &mut App,
320 ) -> Task<gpui::Result<acp::PromptResponse>> {
321 let mut sessions = self.sessions.lock();
322 let Session {
323 thread,
324 response_tx,
325 } = sessions.get_mut(¶ms.session_id).unwrap();
326 let mut tasks = vec![];
327 if self.next_prompt_updates.lock().is_empty() {
328 let (tx, rx) = oneshot::channel();
329 response_tx.replace(tx);
330 cx.spawn(async move |_| {
331 let stop_reason = rx.await?;
332 Ok(acp::PromptResponse { stop_reason })
333 })
334 } else {
335 for update in self.next_prompt_updates.lock().drain(..) {
336 let thread = thread.clone();
337 let update = update.clone();
338 let permission_request = if let acp::SessionUpdate::ToolCall(tool_call) =
339 &update
340 && let Some(options) = self.permission_requests.get(&tool_call.id)
341 {
342 Some((tool_call.clone(), options.clone()))
343 } else {
344 None
345 };
346 let task = cx.spawn(async move |cx| {
347 if let Some((tool_call, options)) = permission_request {
348 let permission = thread.update(cx, |thread, cx| {
349 thread.request_tool_call_authorization(
350 tool_call.clone().into(),
351 options.clone(),
352 cx,
353 )
354 })?;
355 permission?.await?;
356 }
357 thread.update(cx, |thread, cx| {
358 thread.handle_session_update(update.clone(), cx).unwrap();
359 })?;
360 anyhow::Ok(())
361 });
362 tasks.push(task);
363 }
364
365 cx.spawn(async move |_| {
366 try_join_all(tasks).await?;
367 Ok(acp::PromptResponse {
368 stop_reason: acp::StopReason::EndTurn,
369 })
370 })
371 }
372 }
373
374 fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
375 if let Some(end_turn_tx) = self
376 .sessions
377 .lock()
378 .get_mut(session_id)
379 .unwrap()
380 .response_tx
381 .take()
382 {
383 end_turn_tx.send(acp::StopReason::Canceled).unwrap();
384 }
385 }
386
387 fn session_editor(
388 &self,
389 _session_id: &agent_client_protocol::SessionId,
390 _cx: &mut App,
391 ) -> Option<Rc<dyn AgentSessionEditor>> {
392 Some(Rc::new(StubAgentSessionEditor))
393 }
394
395 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
396 self
397 }
398 }
399
400 struct StubAgentSessionEditor;
401
402 impl AgentSessionEditor for StubAgentSessionEditor {
403 fn truncate(&self, _: UserMessageId, _: &mut App) -> Task<Result<()>> {
404 Task::ready(Ok(()))
405 }
406 }
407}
408
409#[cfg(feature = "test-support")]
410pub use test_support::*;