1use crate::AcpThread;
2use agent_client_protocol::{self as acp};
3use anyhow::Result;
4use collections::IndexMap;
5use gpui::{Entity, SharedString, Task};
6use language_model::LanguageModelProviderId;
7use project::Project;
8use std::{any::Any, error::Error, fmt, path::Path, rc::Rc, sync::Arc};
9use ui::{App, IconName};
10use uuid::Uuid;
11
12#[derive(Clone, Debug, Eq, PartialEq)]
13pub struct UserMessageId(Arc<str>);
14
15impl UserMessageId {
16 pub fn new() -> Self {
17 Self(Uuid::new_v4().to_string().into())
18 }
19}
20
21pub trait AgentConnection {
22 fn new_thread(
23 self: Rc<Self>,
24 project: Entity<Project>,
25 cwd: &Path,
26 cx: &mut App,
27 ) -> Task<Result<Entity<AcpThread>>>;
28
29 fn auth_methods(&self) -> &[acp::AuthMethod];
30
31 fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;
32
33 fn prompt(
34 &self,
35 user_message_id: Option<UserMessageId>,
36 params: acp::PromptRequest,
37 cx: &mut App,
38 ) -> Task<Result<acp::PromptResponse>>;
39
40 fn resume(
41 &self,
42 _session_id: &acp::SessionId,
43 _cx: &mut App,
44 ) -> Option<Rc<dyn AgentSessionResume>> {
45 None
46 }
47
48 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
49
50 fn session_editor(
51 &self,
52 _session_id: &acp::SessionId,
53 _cx: &mut App,
54 ) -> Option<Rc<dyn AgentSessionEditor>> {
55 None
56 }
57
58 /// Returns this agent as an [Rc<dyn ModelSelector>] if the model selection capability is supported.
59 ///
60 /// If the agent does not support model selection, returns [None].
61 /// This allows sharing the selector in UI components.
62 fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
63 None
64 }
65
66 fn into_any(self: Rc<Self>) -> Rc<dyn Any>;
67}
68
69impl dyn AgentConnection {
70 pub fn downcast<T: 'static + AgentConnection + Sized>(self: Rc<Self>) -> Option<Rc<T>> {
71 self.into_any().downcast().ok()
72 }
73}
74
/// Optional capability for editing a session's message history.
///
/// Obtained through [`AgentConnection::session_editor`].
pub trait AgentSessionEditor {
    /// Truncates the session's history at the user message identified by `message_id`.
    ///
    /// NOTE(review): whether the identified message itself is kept or removed is up to the
    /// implementation — confirm against the concrete agents.
    fn truncate(&self, message_id: UserMessageId, cx: &mut App) -> Task<Result<()>>;
}
78
/// Optional capability for resuming a session.
///
/// Obtained through [`AgentConnection::resume`].
pub trait AgentSessionResume {
    /// Resumes the session and resolves with the resulting prompt response.
    fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>>;
}
82
83#[derive(Debug)]
84pub struct AuthRequired {
85 pub description: Option<String>,
86 pub provider_id: Option<LanguageModelProviderId>,
87}
88
89impl AuthRequired {
90 pub fn new() -> Self {
91 Self {
92 description: None,
93 provider_id: None,
94 }
95 }
96
97 pub fn with_description(mut self, description: String) -> Self {
98 self.description = Some(description);
99 self
100 }
101
102 pub fn with_language_model_provider(mut self, provider_id: LanguageModelProviderId) -> Self {
103 self.provider_id = Some(provider_id);
104 self
105 }
106}
107
108impl Error for AuthRequired {}
109impl fmt::Display for AuthRequired {
110 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
111 write!(f, "Authentication required")
112 }
113}
114
115/// Trait for agents that support listing, selecting, and querying language models.
116///
117/// This is an optional capability; agents indicate support via [AgentConnection::model_selector].
118pub trait AgentModelSelector: 'static {
119 /// Lists all available language models for this agent.
120 ///
121 /// # Parameters
122 /// - `cx`: The GPUI app context for async operations and global access.
123 ///
124 /// # Returns
125 /// A task resolving to the list of models or an error (e.g., if no models are configured).
126 fn list_models(&self, cx: &mut App) -> Task<Result<AgentModelList>>;
127
128 /// Selects a model for a specific session (thread).
129 ///
130 /// This sets the default model for future interactions in the session.
131 /// If the session doesn't exist or the model is invalid, it returns an error.
132 ///
133 /// # Parameters
134 /// - `session_id`: The ID of the session (thread) to apply the model to.
135 /// - `model`: The model to select (should be one from [list_models]).
136 /// - `cx`: The GPUI app context.
137 ///
138 /// # Returns
139 /// A task resolving to `Ok(())` on success or an error.
140 fn select_model(
141 &self,
142 session_id: acp::SessionId,
143 model_id: AgentModelId,
144 cx: &mut App,
145 ) -> Task<Result<()>>;
146
147 /// Retrieves the currently selected model for a specific session (thread).
148 ///
149 /// # Parameters
150 /// - `session_id`: The ID of the session (thread) to query.
151 /// - `cx`: The GPUI app context.
152 ///
153 /// # Returns
154 /// A task resolving to the selected model (always set) or an error (e.g., session not found).
155 fn selected_model(
156 &self,
157 session_id: &acp::SessionId,
158 cx: &mut App,
159 ) -> Task<Result<AgentModelInfo>>;
160
161 /// Whenever the model list is updated the receiver will be notified.
162 fn watch(&self, cx: &mut App) -> watch::Receiver<()>;
163}
164
165#[derive(Debug, Clone, PartialEq, Eq, Hash)]
166pub struct AgentModelId(pub SharedString);
167
168impl std::ops::Deref for AgentModelId {
169 type Target = SharedString;
170
171 fn deref(&self) -> &Self::Target {
172 &self.0
173 }
174}
175
176impl fmt::Display for AgentModelId {
177 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
178 self.0.fmt(f)
179 }
180}
181
/// Metadata describing a single model offered by an agent.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AgentModelInfo {
    /// Identifier passed to [`AgentModelSelector::select_model`] to choose this model.
    pub id: AgentModelId,
    /// Name of the model (presumably the user-facing label — verify against UI callers).
    pub name: SharedString,
    /// Optional icon associated with the model.
    pub icon: Option<IconName>,
}
188
/// Name of a group of models; used as the key of [`AgentModelList::Grouped`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct AgentModelGroupName(pub SharedString);
191
/// The models an agent offers, either as one flat list or split into named groups.
#[derive(Debug, Clone)]
pub enum AgentModelList {
    /// All models in a single ungrouped list.
    Flat(Vec<AgentModelInfo>),
    /// Models keyed by group name.
    ///
    /// NOTE(review): assuming `collections::IndexMap` is the `indexmap` type, groups iterate in
    /// insertion order — confirm.
    Grouped(IndexMap<AgentModelGroupName, Vec<AgentModelInfo>>),
}
197
198impl AgentModelList {
199 pub fn is_empty(&self) -> bool {
200 match self {
201 AgentModelList::Flat(models) => models.is_empty(),
202 AgentModelList::Grouped(groups) => groups.is_empty(),
203 }
204 }
205}
206
#[cfg(feature = "test-support")]
mod test_support {
    //! In-memory [`AgentConnection`] implementation for tests.
    //!
    //! A test drives each turn of [`StubAgentConnection`] either by queuing updates with
    //! [`StubAgentConnection::set_next_prompt_updates`] before prompting, or by pushing updates
    //! with [`StubAgentConnection::send_update`] and finishing the turn with
    //! [`StubAgentConnection::end_turn`].

    use std::sync::Arc;

    use collections::HashMap;
    use futures::{channel::oneshot, future::try_join_all};
    use gpui::{AppContext as _, WeakEntity};
    use parking_lot::Mutex;

    use super::*;

    /// Test double for [`AgentConnection`] that tracks sessions in memory.
    ///
    /// Clones share the session table and the queue of scripted updates (both are behind
    /// `Arc<Mutex<_>>`); the permission-request table is copied per clone.
    #[derive(Clone, Default)]
    pub struct StubAgentConnection {
        sessions: Arc<Mutex<HashMap<acp::SessionId, Session>>>,
        /// Tool calls (by id) whose `ToolCall` update must go through
        /// `request_tool_call_authorization` with these options before it is applied.
        permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
        /// Updates replayed by the next `prompt` call. When empty, `prompt` leaves the turn
        /// pending until `end_turn` or `cancel` resolves it.
        next_prompt_updates: Arc<Mutex<Vec<acp::SessionUpdate>>>,
    }

    /// Per-session state kept by [`StubAgentConnection`].
    struct Session {
        /// The thread created for this session, held weakly.
        thread: WeakEntity<AcpThread>,
        /// Completes the pending `prompt` turn, if one is waiting for `end_turn`/`cancel`.
        response_tx: Option<oneshot::Sender<acp::StopReason>>,
    }

    impl StubAgentConnection {
        /// Creates a stub with no sessions, no queued updates, and no permission requests.
        pub fn new() -> Self {
            Self::default()
        }

        /// Queues `updates` for the next `prompt` call, replacing anything already queued.
        pub fn set_next_prompt_updates(&self, updates: Vec<acp::SessionUpdate>) {
            *self.next_prompt_updates.lock() = updates;
        }

        /// Registers tool calls whose updates must be authorized before they are applied.
        pub fn with_permission_requests(
            mut self,
            permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
        ) -> Self {
            self.permission_requests = permission_requests;
            self
        }

        /// Immediately applies `update` to the thread of `session_id`.
        ///
        /// # Panics
        /// If updates are queued via `set_next_prompt_updates`, if the session is unknown, if
        /// its thread has been released, or if the thread fails to handle the update.
        pub fn send_update(
            &self,
            session_id: acp::SessionId,
            update: acp::SessionUpdate,
            cx: &mut App,
        ) {
            assert!(
                self.next_prompt_updates.lock().is_empty(),
                "Use either send_update or set_next_prompt_updates"
            );

            // Copy the weak handle out so the sessions lock is released (at the end of this
            // statement) before the thread's update handler runs.
            let thread = self
                .sessions
                .lock()
                .get(&session_id)
                .expect("send_update called for an unknown session")
                .thread
                .clone();

            thread
                .update(cx, |thread, cx| {
                    thread.handle_session_update(update, cx).unwrap();
                })
                .unwrap();
        }

        /// Resolves the pending turn of `session_id` with `stop_reason`.
        ///
        /// # Panics
        /// If the session is unknown, no turn is pending, or the pending turn's receiver was
        /// dropped.
        pub fn end_turn(&self, session_id: acp::SessionId, stop_reason: acp::StopReason) {
            self.sessions
                .lock()
                .get_mut(&session_id)
                .unwrap()
                .response_tx
                .take()
                .expect("No pending turn")
                .send(stop_reason)
                .unwrap();
        }
    }

    impl AgentConnection for StubAgentConnection {
        fn auth_methods(&self) -> &[acp::AuthMethod] {
            &[]
        }

        fn new_thread(
            self: Rc<Self>,
            project: Entity<Project>,
            _cwd: &Path,
            cx: &mut gpui::App,
        ) -> Task<gpui::Result<Entity<AcpThread>>> {
            // Session ids are the number of sessions created before this one.
            let session_id = acp::SessionId(self.sessions.lock().len().to_string().into());
            let thread =
                cx.new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx));
            self.sessions.lock().insert(
                session_id,
                Session {
                    thread: thread.downgrade(),
                    response_tx: None,
                },
            );
            Task::ready(Ok(thread))
        }

        fn authenticate(
            &self,
            _method_id: acp::AuthMethodId,
            _cx: &mut App,
        ) -> Task<gpui::Result<()>> {
            unimplemented!()
        }

        /// Runs one prompt turn for `params.session_id`.
        ///
        /// With queued updates, each one is applied on its own task (first requesting
        /// authorization when it is a tool call listed in `permission_requests`), and the turn
        /// ends with `EndTurn` once all tasks succeed. Without queued updates, the turn stays
        /// pending until `end_turn` or `cancel` resolves it.
        fn prompt(
            &self,
            _id: Option<UserMessageId>,
            params: acp::PromptRequest,
            cx: &mut App,
        ) -> Task<gpui::Result<acp::PromptResponse>> {
            let updates = std::mem::take(&mut *self.next_prompt_updates.lock());
            let mut sessions = self.sessions.lock();
            let session = sessions
                .get_mut(&params.session_id)
                .expect("prompt called for an unknown session");

            if updates.is_empty() {
                let (tx, rx) = oneshot::channel();
                session.response_tx.replace(tx);
                return cx.spawn(async move |_| {
                    let stop_reason = rx.await?;
                    Ok(acp::PromptResponse { stop_reason })
                });
            }

            let thread = session.thread.clone();
            // The scripted path no longer needs the session table.
            drop(sessions);

            let mut tasks = Vec::with_capacity(updates.len());
            for update in updates {
                let thread = thread.clone();
                let permission_request = match &update {
                    acp::SessionUpdate::ToolCall(tool_call) => self
                        .permission_requests
                        .get(&tool_call.id)
                        .map(|options| (tool_call.clone(), options.clone())),
                    _ => None,
                };
                tasks.push(cx.spawn(async move |cx| {
                    if let Some((tool_call, options)) = permission_request {
                        let permission = thread.update(cx, |thread, cx| {
                            thread.request_tool_call_authorization(tool_call.into(), options, cx)
                        })?;
                        permission?.await?;
                    }
                    thread.update(cx, |thread, cx| {
                        thread.handle_session_update(update, cx).unwrap();
                    })?;
                    anyhow::Ok(())
                }));
            }

            cx.spawn(async move |_| {
                try_join_all(tasks).await?;
                Ok(acp::PromptResponse {
                    stop_reason: acp::StopReason::EndTurn,
                })
            })
        }

        /// Resolves the pending turn (if any) with `Canceled`.
        ///
        /// # Panics
        /// If the session is unknown or the pending turn's receiver was dropped.
        fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
            let pending = self
                .sessions
                .lock()
                .get_mut(session_id)
                .unwrap()
                .response_tx
                .take();
            if let Some(end_turn_tx) = pending {
                end_turn_tx.send(acp::StopReason::Canceled).unwrap();
            }
        }

        fn session_editor(
            &self,
            _session_id: &agent_client_protocol::SessionId,
            _cx: &mut App,
        ) -> Option<Rc<dyn AgentSessionEditor>> {
            Some(Rc::new(StubAgentSessionEditor))
        }

        fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
            self
        }
    }

    /// Session editor whose `truncate` always succeeds without doing anything.
    struct StubAgentSessionEditor;

    impl AgentSessionEditor for StubAgentSessionEditor {
        fn truncate(&self, _: UserMessageId, _: &mut App) -> Task<Result<()>> {
            Task::ready(Ok(()))
        }
    }
}
411
412#[cfg(feature = "test-support")]
413pub use test_support::*;