1use crate::{AcpThread, AcpThreadMetadata};
2use agent_client_protocol::{self as acp};
3use anyhow::Result;
4use collections::IndexMap;
5use futures::channel::mpsc::UnboundedReceiver;
6use gpui::{Entity, SharedString, Task};
7use project::Project;
8use serde::{Deserialize, Serialize};
9use std::{any::Any, error::Error, fmt, path::Path, rc::Rc, sync::Arc};
10use ui::{App, IconName};
11use uuid::Uuid;
12
13#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
14pub struct UserMessageId(Arc<str>);
15
16impl UserMessageId {
17 pub fn new() -> Self {
18 Self(Uuid::new_v4().to_string().into())
19 }
20}
21
22pub trait AgentConnection {
23 fn new_thread(
24 self: Rc<Self>,
25 project: Entity<Project>,
26 cwd: &Path,
27 cx: &mut App,
28 ) -> Task<Result<Entity<AcpThread>>>;
29
30 fn auth_methods(&self) -> &[acp::AuthMethod];
31
32 fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;
33
34 fn prompt(
35 &self,
36 user_message_id: Option<UserMessageId>,
37 params: acp::PromptRequest,
38 cx: &mut App,
39 ) -> Task<Result<acp::PromptResponse>>;
40
41 fn resume(
42 &self,
43 _session_id: &acp::SessionId,
44 _cx: &mut App,
45 ) -> Option<Rc<dyn AgentSessionResume>> {
46 None
47 }
48
49 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
50
51 fn session_editor(
52 &self,
53 _session_id: &acp::SessionId,
54 _cx: &mut App,
55 ) -> Option<Rc<dyn AgentSessionEditor>> {
56 None
57 }
58
59 /// Returns this agent as an [Rc<dyn ModelSelector>] if the model selection capability is supported.
60 ///
61 /// If the agent does not support model selection, returns [None].
62 /// This allows sharing the selector in UI components.
63 fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
64 None
65 }
66
67 fn history(self: Rc<Self>) -> Option<Rc<dyn AgentHistory>> {
68 None
69 }
70
71 fn into_any(self: Rc<Self>) -> Rc<dyn Any>;
72}
73
74impl dyn AgentConnection {
75 pub fn downcast<T: 'static + AgentConnection + Sized>(self: Rc<Self>) -> Option<Rc<T>> {
76 self.into_any().downcast().ok()
77 }
78}
79
/// Optional capability for editing a session's history, obtained through
/// `AgentConnection::session_editor`.
pub trait AgentSessionEditor {
    /// Truncates the session at the user message identified by `message_id`.
    ///
    /// NOTE(review): whether the message itself is kept or removed is up to the
    /// implementation — confirm against concrete implementors.
    fn truncate(&self, message_id: UserMessageId, cx: &mut App) -> Task<Result<()>>;
}
83
/// Optional capability for resuming a turn, obtained through `AgentConnection::resume`.
pub trait AgentSessionResume {
    /// Resumes the turn; the returned task resolves with the agent's prompt response.
    fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>>;
}
87
/// Optional capability for accessing an agent's thread history, obtained through
/// `AgentConnection::history`.
pub trait AgentHistory {
    /// Lists metadata for the threads known to the agent's history.
    fn list_threads(&self, cx: &mut App) -> Task<Result<Vec<AcpThreadMetadata>>>;
    /// Returns a receiver of thread metadata.
    ///
    /// NOTE(review): presumably yields an item whenever a thread is added or updated —
    /// confirm against implementors.
    fn observe_history(&self, cx: &mut App) -> UnboundedReceiver<AcpThreadMetadata>;
    /// Loads the thread identified by `_session_id` into `_project`.
    ///
    /// `_cwd` is passed through like in `AgentConnection::new_thread`.
    fn load_thread(
        self: Rc<Self>,
        _project: Entity<Project>,
        _cwd: &Path,
        _session_id: acp::SessionId,
        _cx: &mut App,
    ) -> Task<Result<Entity<AcpThread>>>;
}
99
/// Marker error whose `Display` output is the literal text `AuthRequired`.
///
/// NOTE(review): presumably returned when the agent needs `AgentConnection::authenticate`
/// to be called first — confirm with the code that produces it.
#[derive(Debug)]
pub struct AuthRequired;

impl fmt::Display for AuthRequired {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("AuthRequired")
    }
}

impl Error for AuthRequired {}
109
110/// Trait for agents that support listing, selecting, and querying language models.
111///
112/// This is an optional capability; agents indicate support via [AgentConnection::model_selector].
113pub trait AgentModelSelector: 'static {
114 /// Lists all available language models for this agent.
115 ///
116 /// # Parameters
117 /// - `cx`: The GPUI app context for async operations and global access.
118 ///
119 /// # Returns
120 /// A task resolving to the list of models or an error (e.g., if no models are configured).
121 fn list_models(&self, cx: &mut App) -> Task<Result<AgentModelList>>;
122
123 /// Selects a model for a specific session (thread).
124 ///
125 /// This sets the default model for future interactions in the session.
126 /// If the session doesn't exist or the model is invalid, it returns an error.
127 ///
128 /// # Parameters
129 /// - `session_id`: The ID of the session (thread) to apply the model to.
130 /// - `model`: The model to select (should be one from [list_models]).
131 /// - `cx`: The GPUI app context.
132 ///
133 /// # Returns
134 /// A task resolving to `Ok(())` on success or an error.
135 fn select_model(
136 &self,
137 session_id: acp::SessionId,
138 model_id: AgentModelId,
139 cx: &mut App,
140 ) -> Task<Result<()>>;
141
142 /// Retrieves the currently selected model for a specific session (thread).
143 ///
144 /// # Parameters
145 /// - `session_id`: The ID of the session (thread) to query.
146 /// - `cx`: The GPUI app context.
147 ///
148 /// # Returns
149 /// A task resolving to the selected model (always set) or an error (e.g., session not found).
150 fn selected_model(
151 &self,
152 session_id: &acp::SessionId,
153 cx: &mut App,
154 ) -> Task<Result<AgentModelInfo>>;
155
156 /// Whenever the model list is updated the receiver will be notified.
157 fn watch(&self, cx: &mut App) -> watch::Receiver<()>;
158}
159
160#[derive(Debug, Clone, PartialEq, Eq, Hash)]
161pub struct AgentModelId(pub SharedString);
162
163impl std::ops::Deref for AgentModelId {
164 type Target = SharedString;
165
166 fn deref(&self) -> &Self::Target {
167 &self.0
168 }
169}
170
171impl fmt::Display for AgentModelId {
172 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
173 self.0.fmt(f)
174 }
175}
176
/// Description of a single language model offered by an agent.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AgentModelInfo {
    /// Identifier passed to `AgentModelSelector::select_model`.
    pub id: AgentModelId,
    /// Display name of the model.
    pub name: SharedString,
    /// Optional icon shown next to the model.
    pub icon: Option<IconName>,
}
183
/// Name of a group of models; used as the key in `AgentModelList::Grouped`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct AgentModelGroupName(pub SharedString);
186
/// Models returned by `AgentModelSelector::list_models`.
#[derive(Debug, Clone)]
pub enum AgentModelList {
    /// An ungrouped list of models.
    Flat(Vec<AgentModelInfo>),
    /// Models grouped by name; `IndexMap` keeps the groups in insertion order.
    Grouped(IndexMap<AgentModelGroupName, Vec<AgentModelInfo>>),
}
192
193impl AgentModelList {
194 pub fn is_empty(&self) -> bool {
195 match self {
196 AgentModelList::Flat(models) => models.is_empty(),
197 AgentModelList::Grouped(groups) => groups.is_empty(),
198 }
199 }
200}
201
202#[cfg(feature = "test-support")]
203mod test_support {
204 use std::sync::Arc;
205
206 use collections::HashMap;
207 use futures::{channel::oneshot, future::try_join_all};
208 use gpui::{AppContext as _, WeakEntity};
209 use parking_lot::Mutex;
210
211 use super::*;
212
/// Scriptable [`AgentConnection`] for tests.
///
/// There are two ways to drive it, and they must not be mixed:
/// - queue updates with `set_next_prompt_updates`, and the next `prompt` replays them;
/// - push updates with `send_update` and finish a pending turn with `end_turn`.
///
/// Clones share `sessions` and `next_prompt_updates` (both behind `Arc`), but each
/// clone has its own copy of `permission_requests`.
#[derive(Clone, Default)]
pub struct StubAgentConnection {
    /// Sessions created by `new_thread`, keyed by session id.
    sessions: Arc<Mutex<HashMap<acp::SessionId, Session>>>,
    /// Tool calls that must be authorized (with these options) before the queued
    /// `ToolCall` update carrying that id is applied.
    permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
    /// Updates replayed by the next `prompt` call; drained when used.
    next_prompt_updates: Arc<Mutex<Vec<acp::SessionUpdate>>>,
}

/// Per-session state tracked by [`StubAgentConnection`].
struct Session {
    /// Weak handle to the thread created for this session.
    thread: WeakEntity<AcpThread>,
    /// Completes the turn started by `prompt` when no updates were queued;
    /// `None` when no such turn is pending.
    response_tx: Option<oneshot::Sender<acp::StopReason>>,
}
224
225 impl StubAgentConnection {
226 pub fn new() -> Self {
227 Self {
228 next_prompt_updates: Default::default(),
229 permission_requests: HashMap::default(),
230 sessions: Arc::default(),
231 }
232 }
233
234 pub fn set_next_prompt_updates(&self, updates: Vec<acp::SessionUpdate>) {
235 *self.next_prompt_updates.lock() = updates;
236 }
237
238 pub fn with_permission_requests(
239 mut self,
240 permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
241 ) -> Self {
242 self.permission_requests = permission_requests;
243 self
244 }
245
246 pub fn send_update(
247 &self,
248 session_id: acp::SessionId,
249 update: acp::SessionUpdate,
250 cx: &mut App,
251 ) {
252 assert!(
253 self.next_prompt_updates.lock().is_empty(),
254 "Use either send_update or set_next_prompt_updates"
255 );
256
257 self.sessions
258 .lock()
259 .get(&session_id)
260 .unwrap()
261 .thread
262 .update(cx, |thread, cx| {
263 thread.handle_session_update(update, cx).unwrap();
264 })
265 .unwrap();
266 }
267
268 pub fn end_turn(&self, session_id: acp::SessionId, stop_reason: acp::StopReason) {
269 self.sessions
270 .lock()
271 .get_mut(&session_id)
272 .unwrap()
273 .response_tx
274 .take()
275 .expect("No pending turn")
276 .send(stop_reason)
277 .unwrap();
278 }
279 }
280
281 impl AgentConnection for StubAgentConnection {
282 fn auth_methods(&self) -> &[acp::AuthMethod] {
283 &[]
284 }
285
286 fn new_thread(
287 self: Rc<Self>,
288 project: Entity<Project>,
289 _cwd: &Path,
290 cx: &mut gpui::App,
291 ) -> Task<gpui::Result<Entity<AcpThread>>> {
292 let session_id = acp::SessionId(self.sessions.lock().len().to_string().into());
293 let thread =
294 cx.new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx));
295 self.sessions.lock().insert(
296 session_id,
297 Session {
298 thread: thread.downgrade(),
299 response_tx: None,
300 },
301 );
302 Task::ready(Ok(thread))
303 }
304
305 fn authenticate(
306 &self,
307 _method_id: acp::AuthMethodId,
308 _cx: &mut App,
309 ) -> Task<gpui::Result<()>> {
310 unimplemented!()
311 }
312
313 fn prompt(
314 &self,
315 _id: Option<UserMessageId>,
316 params: acp::PromptRequest,
317 cx: &mut App,
318 ) -> Task<gpui::Result<acp::PromptResponse>> {
319 let mut sessions = self.sessions.lock();
320 let Session {
321 thread,
322 response_tx,
323 } = sessions.get_mut(¶ms.session_id).unwrap();
324 let mut tasks = vec![];
325 if self.next_prompt_updates.lock().is_empty() {
326 let (tx, rx) = oneshot::channel();
327 response_tx.replace(tx);
328 cx.spawn(async move |_| {
329 let stop_reason = rx.await?;
330 Ok(acp::PromptResponse { stop_reason })
331 })
332 } else {
333 for update in self.next_prompt_updates.lock().drain(..) {
334 let thread = thread.clone();
335 let update = update.clone();
336 let permission_request = if let acp::SessionUpdate::ToolCall(tool_call) =
337 &update
338 && let Some(options) = self.permission_requests.get(&tool_call.id)
339 {
340 Some((tool_call.clone(), options.clone()))
341 } else {
342 None
343 };
344 let task = cx.spawn(async move |cx| {
345 if let Some((tool_call, options)) = permission_request {
346 let permission = thread.update(cx, |thread, cx| {
347 thread.request_tool_call_authorization(
348 tool_call.clone().into(),
349 options.clone(),
350 cx,
351 )
352 })?;
353 permission?.await?;
354 }
355 thread.update(cx, |thread, cx| {
356 thread.handle_session_update(update.clone(), cx).unwrap();
357 })?;
358 anyhow::Ok(())
359 });
360 tasks.push(task);
361 }
362
363 cx.spawn(async move |_| {
364 try_join_all(tasks).await?;
365 Ok(acp::PromptResponse {
366 stop_reason: acp::StopReason::EndTurn,
367 })
368 })
369 }
370 }
371
372 fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
373 if let Some(end_turn_tx) = self
374 .sessions
375 .lock()
376 .get_mut(session_id)
377 .unwrap()
378 .response_tx
379 .take()
380 {
381 end_turn_tx.send(acp::StopReason::Canceled).unwrap();
382 }
383 }
384
385 fn session_editor(
386 &self,
387 _session_id: &agent_client_protocol::SessionId,
388 _cx: &mut App,
389 ) -> Option<Rc<dyn AgentSessionEditor>> {
390 Some(Rc::new(StubAgentSessionEditor))
391 }
392
393 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
394 self
395 }
396 }
397
398 struct StubAgentSessionEditor;
399
400 impl AgentSessionEditor for StubAgentSessionEditor {
401 fn truncate(&self, _: UserMessageId, _: &mut App) -> Task<Result<()>> {
402 Task::ready(Ok(()))
403 }
404 }
405}
406
407#[cfg(feature = "test-support")]
408pub use test_support::*;