1use crate::AcpThread;
2use agent_client_protocol::{self as acp};
3use anyhow::Result;
4use collections::IndexMap;
5use gpui::{Entity, SharedString, Task};
6use language_model::LanguageModelProviderId;
7use project::Project;
8use std::{any::Any, error::Error, fmt, path::Path, rc::Rc, sync::Arc};
9use ui::{App, IconName};
10use uuid::Uuid;
11
12#[derive(Clone, Debug, Eq, PartialEq)]
13pub struct UserMessageId(Arc<str>);
14
15impl UserMessageId {
16 pub fn new() -> Self {
17 Self(Uuid::new_v4().to_string().into())
18 }
19}
20
21pub trait AgentConnection {
22 fn new_thread(
23 self: Rc<Self>,
24 project: Entity<Project>,
25 cwd: &Path,
26 cx: &mut App,
27 ) -> Task<Result<Entity<AcpThread>>>;
28
29 fn auth_methods(&self) -> &[acp::AuthMethod];
30
31 fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;
32
33 fn prompt(
34 &self,
35 user_message_id: Option<UserMessageId>,
36 params: acp::PromptRequest,
37 cx: &mut App,
38 ) -> Task<Result<acp::PromptResponse>>;
39
40 fn resume(
41 &self,
42 _session_id: &acp::SessionId,
43 _cx: &mut App,
44 ) -> Option<Rc<dyn AgentSessionResume>> {
45 None
46 }
47
48 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
49
50 fn session_editor(
51 &self,
52 _session_id: &acp::SessionId,
53 _cx: &mut App,
54 ) -> Option<Rc<dyn AgentSessionEditor>> {
55 None
56 }
57
58 /// Returns this agent as an [Rc<dyn ModelSelector>] if the model selection capability is supported.
59 ///
60 /// If the agent does not support model selection, returns [None].
61 /// This allows sharing the selector in UI components.
62 fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
63 None
64 }
65
66 fn into_any(self: Rc<Self>) -> Rc<dyn Any>;
67}
68
69impl dyn AgentConnection {
70 pub fn downcast<T: 'static + AgentConnection + Sized>(self: Rc<Self>) -> Option<Rc<T>> {
71 self.into_any().downcast().ok()
72 }
73}
74
/// Session-editing capability, obtained from [`AgentConnection::session_editor`].
pub trait AgentSessionEditor {
    /// Truncates the session at the user message identified by `message_id`.
    ///
    /// NOTE(review): whether `message_id` itself is kept or removed is decided by the
    /// implementation and is not visible here — confirm against implementors.
    fn truncate(&self, message_id: UserMessageId, cx: &mut App) -> Task<Result<()>>;
}
78
/// Resume capability, obtained from [`AgentConnection::resume`].
pub trait AgentSessionResume {
    /// Runs the resumed session, resolving with the agent's [`acp::PromptResponse`].
    fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>>;
}
82
83#[derive(Debug)]
84pub struct AuthRequired {
85 pub description: Option<String>,
86 pub provider_id: Option<LanguageModelProviderId>,
87}
88
89impl AuthRequired {
90 pub fn new() -> Self {
91 Self {
92 description: None,
93 provider_id: None,
94 }
95 }
96
97 pub fn with_description(mut self, description: String) -> Self {
98 self.description = Some(description);
99 self
100 }
101
102 pub fn with_language_model_provider(mut self, provider_id: LanguageModelProviderId) -> Self {
103 self.provider_id = Some(provider_id);
104 self
105 }
106}
107
108impl Error for AuthRequired {}
109impl fmt::Display for AuthRequired {
110 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
111 write!(f, "Authentication required")
112 }
113}
114
115/// Trait for agents that support listing, selecting, and querying language models.
116///
117/// This is an optional capability; agents indicate support via [AgentConnection::model_selector].
118pub trait AgentModelSelector: 'static {
119 /// Lists all available language models for this agent.
120 ///
121 /// # Parameters
122 /// - `cx`: The GPUI app context for async operations and global access.
123 ///
124 /// # Returns
125 /// A task resolving to the list of models or an error (e.g., if no models are configured).
126 fn list_models(&self, cx: &mut App) -> Task<Result<AgentModelList>>;
127
128 /// Selects a model for a specific session (thread).
129 ///
130 /// This sets the default model for future interactions in the session.
131 /// If the session doesn't exist or the model is invalid, it returns an error.
132 ///
133 /// # Parameters
134 /// - `session_id`: The ID of the session (thread) to apply the model to.
135 /// - `model`: The model to select (should be one from [list_models]).
136 /// - `cx`: The GPUI app context.
137 ///
138 /// # Returns
139 /// A task resolving to `Ok(())` on success or an error.
140 fn select_model(
141 &self,
142 session_id: acp::SessionId,
143 model_id: AgentModelId,
144 cx: &mut App,
145 ) -> Task<Result<()>>;
146
147 /// Retrieves the currently selected model for a specific session (thread).
148 ///
149 /// # Parameters
150 /// - `session_id`: The ID of the session (thread) to query.
151 /// - `cx`: The GPUI app context.
152 ///
153 /// # Returns
154 /// A task resolving to the selected model (always set) or an error (e.g., session not found).
155 fn selected_model(
156 &self,
157 session_id: &acp::SessionId,
158 cx: &mut App,
159 ) -> Task<Result<AgentModelInfo>>;
160
161 /// Whenever the model list is updated the receiver will be notified.
162 fn watch(&self, cx: &mut App) -> watch::Receiver<()>;
163}
164
165#[derive(Debug, Clone, PartialEq, Eq, Hash)]
166pub struct AgentModelId(pub SharedString);
167
168impl std::ops::Deref for AgentModelId {
169 type Target = SharedString;
170
171 fn deref(&self) -> &Self::Target {
172 &self.0
173 }
174}
175
176impl fmt::Display for AgentModelId {
177 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
178 self.0.fmt(f)
179 }
180}
181
/// Description of a model offered by an [`AgentModelSelector`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AgentModelInfo {
    /// Model identifier, the type accepted by [`AgentModelSelector::select_model`].
    pub id: AgentModelId,
    /// Model name (presumably the user-facing label — confirm with UI callers).
    pub name: SharedString,
    /// Optional icon associated with the model.
    pub icon: Option<IconName>,
}
188
/// Name of a model group; used as the key of [`AgentModelList::Grouped`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct AgentModelGroupName(pub SharedString);
191
192#[derive(Debug, Clone)]
193pub enum AgentModelList {
194 Flat(Vec<AgentModelInfo>),
195 Grouped(IndexMap<AgentModelGroupName, Vec<AgentModelInfo>>),
196}
197
198impl AgentModelList {
199 pub fn is_empty(&self) -> bool {
200 match self {
201 AgentModelList::Flat(models) => models.is_empty(),
202 AgentModelList::Grouped(groups) => groups.is_empty(),
203 }
204 }
205}
206
#[cfg(feature = "test-support")]
mod test_support {
    use std::sync::Arc;

    use collections::HashMap;
    use futures::{channel::oneshot, future::try_join_all};
    use gpui::{AppContext as _, WeakEntity};
    use parking_lot::Mutex;

    use super::*;

    /// Scriptable in-memory [`AgentConnection`] for tests.
    ///
    /// A prompt either replays the updates queued with
    /// [`StubAgentConnection::set_next_prompt_updates`] and then ends the turn, or, when
    /// nothing is queued, stays pending until [`StubAgentConnection::end_turn`] is called.
    /// Clones share sessions and queued updates; permission requests are copied.
    #[derive(Clone, Default)]
    pub struct StubAgentConnection {
        /// Sessions created by `new_thread`, keyed by session id.
        sessions: Arc<Mutex<HashMap<acp::SessionId, Session>>>,
        /// Options to request authorization with before applying a queued tool call
        /// whose id is a key of this map.
        permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
        /// Updates that the next call to `prompt` replays and consumes.
        next_prompt_updates: Arc<Mutex<Vec<acp::SessionUpdate>>>,
    }

    /// Per-session state tracked by [`StubAgentConnection`].
    struct Session {
        thread: WeakEntity<AcpThread>,
        /// Completes the turn left pending by a `prompt` call that had no queued updates.
        response_tx: Option<oneshot::Sender<()>>,
    }

    impl StubAgentConnection {
        /// Creates a connection with no sessions, queued updates, or permission requests.
        pub fn new() -> Self {
            Self::default()
        }

        /// Queues `updates` for the next call to `prompt`, replacing anything queued before.
        pub fn set_next_prompt_updates(&self, updates: Vec<acp::SessionUpdate>) {
            *self.next_prompt_updates.lock() = updates;
        }

        /// Makes queued tool-call updates whose id is a key of `permission_requests` ask
        /// for authorization with the mapped options before they are applied.
        pub fn with_permission_requests(
            mut self,
            permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
        ) -> Self {
            self.permission_requests = permission_requests;
            self
        }

        /// Immediately applies `update` to the thread of `session_id`.
        ///
        /// # Panics
        ///
        /// Panics if updates are queued via `set_next_prompt_updates` (the two modes must
        /// not be mixed), if the session does not exist, if its thread has been released,
        /// or if the thread fails to handle the update.
        pub fn send_update(
            &self,
            session_id: acp::SessionId,
            update: acp::SessionUpdate,
            cx: &mut App,
        ) {
            assert!(
                self.next_prompt_updates.lock().is_empty(),
                "Use either send_update or set_next_prompt_updates"
            );

            // Clone the weak handle out so the (non-reentrant) sessions lock is not held
            // while the thread processes the update.
            let thread = self
                .sessions
                .lock()
                .get(&session_id)
                .expect("send_update called with an unknown session id")
                .thread
                .clone();
            thread
                .update(cx, |thread, cx| {
                    thread.handle_session_update(update, cx).unwrap();
                })
                .unwrap();
        }

        /// Ends the turn left pending by the last `prompt` on `session_id`.
        ///
        /// # Panics
        ///
        /// Panics if the session does not exist, if no turn is pending, or if the pending
        /// prompt task has already been dropped.
        pub fn end_turn(&self, session_id: acp::SessionId) {
            self.sessions
                .lock()
                .get_mut(&session_id)
                .expect("end_turn called with an unknown session id")
                .response_tx
                .take()
                .expect("No pending turn")
                .send(())
                .unwrap();
        }
    }

    impl AgentConnection for StubAgentConnection {
        fn auth_methods(&self) -> &[acp::AuthMethod] {
            &[]
        }

        /// Creates an [`AcpThread`] titled "Test" whose session id is the number of
        /// sessions created before it.
        fn new_thread(
            self: Rc<Self>,
            project: Entity<Project>,
            _cwd: &Path,
            cx: &mut gpui::App,
        ) -> Task<gpui::Result<Entity<AcpThread>>> {
            let session_id = acp::SessionId(self.sessions.lock().len().to_string().into());
            let thread =
                cx.new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx));
            self.sessions.lock().insert(
                session_id,
                Session {
                    thread: thread.downgrade(),
                    response_tx: None,
                },
            );
            Task::ready(Ok(thread))
        }

        /// Not supported by the stub; panics.
        fn authenticate(
            &self,
            _method_id: acp::AuthMethodId,
            _cx: &mut App,
        ) -> Task<gpui::Result<()>> {
            unimplemented!("StubAgentConnection does not support authentication")
        }

        /// Replays queued updates (awaiting any configured permission requests first) and
        /// ends the turn once all of them have been applied; with nothing queued, the turn
        /// stays pending until [`StubAgentConnection::end_turn`].
        fn prompt(
            &self,
            _id: Option<UserMessageId>,
            params: acp::PromptRequest,
            cx: &mut App,
        ) -> Task<gpui::Result<acp::PromptResponse>> {
            let mut sessions = self.sessions.lock();
            let Session {
                thread,
                response_tx,
            } = sessions.get_mut(&params.session_id).unwrap();

            // Take the queue under a single lock so the emptiness check and the replay agree.
            let updates = std::mem::take(&mut *self.next_prompt_updates.lock());

            if updates.is_empty() {
                let (tx, rx) = oneshot::channel();
                response_tx.replace(tx);
                return cx.spawn(async move |_| {
                    rx.await?;
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                });
            }

            let mut tasks = Vec::with_capacity(updates.len());
            for update in updates {
                let thread = thread.clone();
                let permission_request = match &update {
                    acp::SessionUpdate::ToolCall(tool_call) => self
                        .permission_requests
                        .get(&tool_call.id)
                        .map(|options| (tool_call.clone(), options.clone())),
                    _ => None,
                };
                tasks.push(cx.spawn(async move |cx| {
                    if let Some((tool_call, options)) = permission_request {
                        let permission = thread.update(cx, |thread, cx| {
                            thread.request_tool_call_authorization(tool_call.into(), options, cx)
                        })?;
                        permission?.await?;
                    }
                    thread.update(cx, |thread, cx| {
                        thread.handle_session_update(update, cx).unwrap();
                    })?;
                    anyhow::Ok(())
                }));
            }

            cx.spawn(async move |_| {
                try_join_all(tasks).await?;
                Ok(acp::PromptResponse {
                    stop_reason: acp::StopReason::EndTurn,
                })
            })
        }

        /// Not supported by the stub; panics.
        fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) {
            unimplemented!("StubAgentConnection does not support cancellation")
        }

        fn session_editor(
            &self,
            _session_id: &acp::SessionId,
            _cx: &mut App,
        ) -> Option<Rc<dyn AgentSessionEditor>> {
            Some(Rc::new(StubAgentSessionEditor))
        }

        fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
            self
        }
    }

    /// Session editor whose `truncate` always succeeds without doing anything.
    struct StubAgentSessionEditor;

    impl AgentSessionEditor for StubAgentSessionEditor {
        fn truncate(&self, _: UserMessageId, _: &mut App) -> Task<Result<()>> {
            Task::ready(Ok(()))
        }
    }
}
404
405#[cfg(feature = "test-support")]
406pub use test_support::*;