1use crate::AcpThread;
2use agent_client_protocol::{self as acp};
3use anyhow::Result;
4use collections::IndexMap;
5use gpui::{Entity, SharedString, Task};
6use language_model::LanguageModelProviderId;
7use project::Project;
8use serde::{Deserialize, Serialize};
9use std::{any::Any, error::Error, fmt, path::Path, rc::Rc, sync::Arc};
10use ui::{App, IconName};
11use uuid::Uuid;
12
13#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
14pub struct UserMessageId(Arc<str>);
15
16impl UserMessageId {
17 pub fn new() -> Self {
18 Self(Uuid::new_v4().to_string().into())
19 }
20}
21
22pub trait AgentConnection {
23 fn new_thread(
24 self: Rc<Self>,
25 project: Entity<Project>,
26 cwd: &Path,
27 cx: &mut App,
28 ) -> Task<Result<Entity<AcpThread>>>;
29
30 fn auth_methods(&self) -> &[acp::AuthMethod];
31
32 fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;
33
34 fn prompt(
35 &self,
36 user_message_id: Option<UserMessageId>,
37 params: acp::PromptRequest,
38 cx: &mut App,
39 ) -> Task<Result<acp::PromptResponse>>;
40
41 fn resume(
42 &self,
43 _session_id: &acp::SessionId,
44 _cx: &mut App,
45 ) -> Option<Rc<dyn AgentSessionResume>> {
46 None
47 }
48
49 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
50
51 fn session_editor(
52 &self,
53 _session_id: &acp::SessionId,
54 _cx: &mut App,
55 ) -> Option<Rc<dyn AgentSessionEditor>> {
56 None
57 }
58
59 /// Returns this agent as an [Rc<dyn ModelSelector>] if the model selection capability is supported.
60 ///
61 /// If the agent does not support model selection, returns [None].
62 /// This allows sharing the selector in UI components.
63 fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
64 None
65 }
66
67 fn into_any(self: Rc<Self>) -> Rc<dyn Any>;
68}
69
70impl dyn AgentConnection {
71 pub fn downcast<T: 'static + AgentConnection + Sized>(self: Rc<Self>) -> Option<Rc<T>> {
72 self.into_any().downcast().ok()
73 }
74}
75
76pub trait AgentSessionEditor {
77 fn truncate(&self, message_id: UserMessageId, cx: &mut App) -> Task<Result<()>>;
78}
79
80pub trait AgentSessionResume {
81 fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>>;
82}
83
84#[derive(Debug)]
85pub struct AuthRequired {
86 pub description: Option<String>,
87 pub provider_id: Option<LanguageModelProviderId>,
88}
89
90impl AuthRequired {
91 pub fn new() -> Self {
92 Self {
93 description: None,
94 provider_id: None,
95 }
96 }
97
98 pub fn with_description(mut self, description: String) -> Self {
99 self.description = Some(description);
100 self
101 }
102
103 pub fn with_language_model_provider(mut self, provider_id: LanguageModelProviderId) -> Self {
104 self.provider_id = Some(provider_id);
105 self
106 }
107}
108
109impl Error for AuthRequired {}
110impl fmt::Display for AuthRequired {
111 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
112 write!(f, "Authentication required")
113 }
114}
115
116/// Trait for agents that support listing, selecting, and querying language models.
117///
118/// This is an optional capability; agents indicate support via [AgentConnection::model_selector].
119pub trait AgentModelSelector: 'static {
120 /// Lists all available language models for this agent.
121 ///
122 /// # Parameters
123 /// - `cx`: The GPUI app context for async operations and global access.
124 ///
125 /// # Returns
126 /// A task resolving to the list of models or an error (e.g., if no models are configured).
127 fn list_models(&self, cx: &mut App) -> Task<Result<AgentModelList>>;
128
129 /// Selects a model for a specific session (thread).
130 ///
131 /// This sets the default model for future interactions in the session.
132 /// If the session doesn't exist or the model is invalid, it returns an error.
133 ///
134 /// # Parameters
135 /// - `session_id`: The ID of the session (thread) to apply the model to.
136 /// - `model`: The model to select (should be one from [list_models]).
137 /// - `cx`: The GPUI app context.
138 ///
139 /// # Returns
140 /// A task resolving to `Ok(())` on success or an error.
141 fn select_model(
142 &self,
143 session_id: acp::SessionId,
144 model_id: AgentModelId,
145 cx: &mut App,
146 ) -> Task<Result<()>>;
147
148 /// Retrieves the currently selected model for a specific session (thread).
149 ///
150 /// # Parameters
151 /// - `session_id`: The ID of the session (thread) to query.
152 /// - `cx`: The GPUI app context.
153 ///
154 /// # Returns
155 /// A task resolving to the selected model (always set) or an error (e.g., session not found).
156 fn selected_model(
157 &self,
158 session_id: &acp::SessionId,
159 cx: &mut App,
160 ) -> Task<Result<AgentModelInfo>>;
161
162 /// Whenever the model list is updated the receiver will be notified.
163 fn watch(&self, cx: &mut App) -> watch::Receiver<()>;
164}
165
166#[derive(Debug, Clone, PartialEq, Eq, Hash)]
167pub struct AgentModelId(pub SharedString);
168
169impl std::ops::Deref for AgentModelId {
170 type Target = SharedString;
171
172 fn deref(&self) -> &Self::Target {
173 &self.0
174 }
175}
176
177impl fmt::Display for AgentModelId {
178 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
179 self.0.fmt(f)
180 }
181}
182
183#[derive(Debug, Clone, PartialEq, Eq)]
184pub struct AgentModelInfo {
185 pub id: AgentModelId,
186 pub name: SharedString,
187 pub icon: Option<IconName>,
188}
189
190#[derive(Debug, Clone, PartialEq, Eq, Hash)]
191pub struct AgentModelGroupName(pub SharedString);
192
193#[derive(Debug, Clone)]
194pub enum AgentModelList {
195 Flat(Vec<AgentModelInfo>),
196 Grouped(IndexMap<AgentModelGroupName, Vec<AgentModelInfo>>),
197}
198
199impl AgentModelList {
200 pub fn is_empty(&self) -> bool {
201 match self {
202 AgentModelList::Flat(models) => models.is_empty(),
203 AgentModelList::Grouped(groups) => groups.is_empty(),
204 }
205 }
206}
207
208#[cfg(feature = "test-support")]
209mod test_support {
210 use std::sync::Arc;
211
212 use action_log::ActionLog;
213 use collections::HashMap;
214 use futures::{channel::oneshot, future::try_join_all};
215 use gpui::{AppContext as _, WeakEntity};
216 use parking_lot::Mutex;
217
218 use super::*;
219
220 #[derive(Clone, Default)]
221 pub struct StubAgentConnection {
222 sessions: Arc<Mutex<HashMap<acp::SessionId, Session>>>,
223 permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
224 next_prompt_updates: Arc<Mutex<Vec<acp::SessionUpdate>>>,
225 }
226
227 struct Session {
228 thread: WeakEntity<AcpThread>,
229 response_tx: Option<oneshot::Sender<acp::StopReason>>,
230 }
231
232 impl StubAgentConnection {
233 pub fn new() -> Self {
234 Self {
235 next_prompt_updates: Default::default(),
236 permission_requests: HashMap::default(),
237 sessions: Arc::default(),
238 }
239 }
240
241 pub fn set_next_prompt_updates(&self, updates: Vec<acp::SessionUpdate>) {
242 *self.next_prompt_updates.lock() = updates;
243 }
244
245 pub fn with_permission_requests(
246 mut self,
247 permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
248 ) -> Self {
249 self.permission_requests = permission_requests;
250 self
251 }
252
253 pub fn send_update(
254 &self,
255 session_id: acp::SessionId,
256 update: acp::SessionUpdate,
257 cx: &mut App,
258 ) {
259 assert!(
260 self.next_prompt_updates.lock().is_empty(),
261 "Use either send_update or set_next_prompt_updates"
262 );
263
264 self.sessions
265 .lock()
266 .get(&session_id)
267 .unwrap()
268 .thread
269 .update(cx, |thread, cx| {
270 thread.handle_session_update(update, cx).unwrap();
271 })
272 .unwrap();
273 }
274
275 pub fn end_turn(&self, session_id: acp::SessionId, stop_reason: acp::StopReason) {
276 self.sessions
277 .lock()
278 .get_mut(&session_id)
279 .unwrap()
280 .response_tx
281 .take()
282 .expect("No pending turn")
283 .send(stop_reason)
284 .unwrap();
285 }
286 }
287
288 impl AgentConnection for StubAgentConnection {
289 fn auth_methods(&self) -> &[acp::AuthMethod] {
290 &[]
291 }
292
293 fn new_thread(
294 self: Rc<Self>,
295 project: Entity<Project>,
296 _cwd: &Path,
297 cx: &mut gpui::App,
298 ) -> Task<gpui::Result<Entity<AcpThread>>> {
299 let session_id = acp::SessionId(self.sessions.lock().len().to_string().into());
300 let action_log = cx.new(|_| ActionLog::new(project.clone()));
301 let thread = cx.new(|_cx| {
302 AcpThread::new(
303 "Test",
304 self.clone(),
305 project,
306 action_log,
307 session_id.clone(),
308 )
309 });
310 self.sessions.lock().insert(
311 session_id,
312 Session {
313 thread: thread.downgrade(),
314 response_tx: None,
315 },
316 );
317 Task::ready(Ok(thread))
318 }
319
320 fn authenticate(
321 &self,
322 _method_id: acp::AuthMethodId,
323 _cx: &mut App,
324 ) -> Task<gpui::Result<()>> {
325 unimplemented!()
326 }
327
328 fn prompt(
329 &self,
330 _id: Option<UserMessageId>,
331 params: acp::PromptRequest,
332 cx: &mut App,
333 ) -> Task<gpui::Result<acp::PromptResponse>> {
334 let mut sessions = self.sessions.lock();
335 let Session {
336 thread,
337 response_tx,
338 } = sessions.get_mut(¶ms.session_id).unwrap();
339 let mut tasks = vec![];
340 if self.next_prompt_updates.lock().is_empty() {
341 let (tx, rx) = oneshot::channel();
342 response_tx.replace(tx);
343 cx.spawn(async move |_| {
344 let stop_reason = rx.await?;
345 Ok(acp::PromptResponse { stop_reason })
346 })
347 } else {
348 for update in self.next_prompt_updates.lock().drain(..) {
349 let thread = thread.clone();
350 let update = update.clone();
351 let permission_request = if let acp::SessionUpdate::ToolCall(tool_call) =
352 &update
353 && let Some(options) = self.permission_requests.get(&tool_call.id)
354 {
355 Some((tool_call.clone(), options.clone()))
356 } else {
357 None
358 };
359 let task = cx.spawn(async move |cx| {
360 if let Some((tool_call, options)) = permission_request {
361 let permission = thread.update(cx, |thread, cx| {
362 thread.request_tool_call_authorization(
363 tool_call.clone().into(),
364 options.clone(),
365 cx,
366 )
367 })?;
368 permission?.await?;
369 }
370 thread.update(cx, |thread, cx| {
371 thread.handle_session_update(update.clone(), cx).unwrap();
372 })?;
373 anyhow::Ok(())
374 });
375 tasks.push(task);
376 }
377
378 cx.spawn(async move |_| {
379 try_join_all(tasks).await?;
380 Ok(acp::PromptResponse {
381 stop_reason: acp::StopReason::EndTurn,
382 })
383 })
384 }
385 }
386
387 fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
388 if let Some(end_turn_tx) = self
389 .sessions
390 .lock()
391 .get_mut(session_id)
392 .unwrap()
393 .response_tx
394 .take()
395 {
396 end_turn_tx.send(acp::StopReason::Canceled).unwrap();
397 }
398 }
399
400 fn session_editor(
401 &self,
402 _session_id: &agent_client_protocol::SessionId,
403 _cx: &mut App,
404 ) -> Option<Rc<dyn AgentSessionEditor>> {
405 Some(Rc::new(StubAgentSessionEditor))
406 }
407
408 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
409 self
410 }
411 }
412
413 struct StubAgentSessionEditor;
414
415 impl AgentSessionEditor for StubAgentSessionEditor {
416 fn truncate(&self, _: UserMessageId, _: &mut App) -> Task<Result<()>> {
417 Task::ready(Ok(()))
418 }
419 }
420}
421
422#[cfg(feature = "test-support")]
423pub use test_support::*;