1use crate::{AcpThread, AcpThreadMetadata};
2use agent_client_protocol::{self as acp};
3use anyhow::Result;
4use collections::IndexMap;
5use futures::channel::mpsc::UnboundedReceiver;
6use gpui::{Entity, SharedString, Task};
7use project::Project;
8use serde::{Deserialize, Serialize};
9use std::{any::Any, error::Error, fmt, path::Path, rc::Rc, sync::Arc};
10use ui::{App, IconName};
11use uuid::Uuid;
12
13#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
14pub struct UserMessageId(Arc<str>);
15
16impl UserMessageId {
17 pub fn new() -> Self {
18 Self(Uuid::new_v4().to_string().into())
19 }
20}
21
22pub trait AgentConnection {
23 fn new_thread(
24 self: Rc<Self>,
25 project: Entity<Project>,
26 cwd: &Path,
27 cx: &mut App,
28 ) -> Task<Result<Entity<AcpThread>>>;
29
30 fn list_threads(
31 &self,
32 _cx: &mut App,
33 ) -> Option<watch::Receiver<Option<Vec<AcpThreadMetadata>>>> {
34 return None;
35 }
36
37 fn load_thread(
38 self: Rc<Self>,
39 _project: Entity<Project>,
40 _cwd: &Path,
41 _session_id: acp::SessionId,
42 _cx: &mut App,
43 ) -> Task<Result<Entity<AcpThread>>> {
44 Task::ready(Err(anyhow::anyhow!("load thread not implemented")))
45 }
46
47 fn auth_methods(&self) -> &[acp::AuthMethod];
48
49 fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;
50
51 fn prompt(
52 &self,
53 user_message_id: Option<UserMessageId>,
54 params: acp::PromptRequest,
55 cx: &mut App,
56 ) -> Task<Result<acp::PromptResponse>>;
57
58 fn resume(
59 &self,
60 _session_id: &acp::SessionId,
61 _cx: &mut App,
62 ) -> Option<Rc<dyn AgentSessionResume>> {
63 None
64 }
65
66 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
67
68 fn session_editor(
69 &self,
70 _session_id: &acp::SessionId,
71 _cx: &mut App,
72 ) -> Option<Rc<dyn AgentSessionEditor>> {
73 None
74 }
75
76 /// Returns this agent as an [Rc<dyn ModelSelector>] if the model selection capability is supported.
77 ///
78 /// If the agent does not support model selection, returns [None].
79 /// This allows sharing the selector in UI components.
80 fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
81 None
82 }
83
84 fn into_any(self: Rc<Self>) -> Rc<dyn Any>;
85}
86
87impl dyn AgentConnection {
88 pub fn downcast<T: 'static + AgentConnection + Sized>(self: Rc<Self>) -> Option<Rc<T>> {
89 self.into_any().downcast().ok()
90 }
91}
92
93pub trait AgentSessionEditor {
94 fn truncate(&self, message_id: UserMessageId, cx: &mut App) -> Task<Result<()>>;
95}
96
97pub trait AgentSessionResume {
98 fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>>;
99}
100
/// Error value indicating that the agent requires authentication.
#[derive(Debug)]
pub struct AuthRequired;

impl fmt::Display for AuthRequired {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("AuthRequired")
    }
}

impl Error for AuthRequired {}
110
111/// Trait for agents that support listing, selecting, and querying language models.
112///
113/// This is an optional capability; agents indicate support via [AgentConnection::model_selector].
114pub trait AgentModelSelector: 'static {
115 /// Lists all available language models for this agent.
116 ///
117 /// # Parameters
118 /// - `cx`: The GPUI app context for async operations and global access.
119 ///
120 /// # Returns
121 /// A task resolving to the list of models or an error (e.g., if no models are configured).
122 fn list_models(&self, cx: &mut App) -> Task<Result<AgentModelList>>;
123
124 /// Selects a model for a specific session (thread).
125 ///
126 /// This sets the default model for future interactions in the session.
127 /// If the session doesn't exist or the model is invalid, it returns an error.
128 ///
129 /// # Parameters
130 /// - `session_id`: The ID of the session (thread) to apply the model to.
131 /// - `model`: The model to select (should be one from [list_models]).
132 /// - `cx`: The GPUI app context.
133 ///
134 /// # Returns
135 /// A task resolving to `Ok(())` on success or an error.
136 fn select_model(
137 &self,
138 session_id: acp::SessionId,
139 model_id: AgentModelId,
140 cx: &mut App,
141 ) -> Task<Result<()>>;
142
143 /// Retrieves the currently selected model for a specific session (thread).
144 ///
145 /// # Parameters
146 /// - `session_id`: The ID of the session (thread) to query.
147 /// - `cx`: The GPUI app context.
148 ///
149 /// # Returns
150 /// A task resolving to the selected model (always set) or an error (e.g., session not found).
151 fn selected_model(
152 &self,
153 session_id: &acp::SessionId,
154 cx: &mut App,
155 ) -> Task<Result<AgentModelInfo>>;
156
157 /// Whenever the model list is updated the receiver will be notified.
158 fn watch(&self, cx: &mut App) -> watch::Receiver<()>;
159}
160
161#[derive(Debug, Clone, PartialEq, Eq, Hash)]
162pub struct AgentModelId(pub SharedString);
163
164impl std::ops::Deref for AgentModelId {
165 type Target = SharedString;
166
167 fn deref(&self) -> &Self::Target {
168 &self.0
169 }
170}
171
172impl fmt::Display for AgentModelId {
173 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
174 self.0.fmt(f)
175 }
176}
177
178#[derive(Debug, Clone, PartialEq, Eq)]
179pub struct AgentModelInfo {
180 pub id: AgentModelId,
181 pub name: SharedString,
182 pub icon: Option<IconName>,
183}
184
185#[derive(Debug, Clone, PartialEq, Eq, Hash)]
186pub struct AgentModelGroupName(pub SharedString);
187
188#[derive(Debug, Clone)]
189pub enum AgentModelList {
190 Flat(Vec<AgentModelInfo>),
191 Grouped(IndexMap<AgentModelGroupName, Vec<AgentModelInfo>>),
192}
193
194impl AgentModelList {
195 pub fn is_empty(&self) -> bool {
196 match self {
197 AgentModelList::Flat(models) => models.is_empty(),
198 AgentModelList::Grouped(groups) => groups.is_empty(),
199 }
200 }
201}
202
203#[cfg(feature = "test-support")]
204mod test_support {
205 use std::sync::Arc;
206
207 use collections::HashMap;
208 use futures::future::try_join_all;
209 use gpui::{AppContext as _, WeakEntity};
210 use parking_lot::Mutex;
211
212 use super::*;
213
214 #[derive(Clone, Default)]
215 pub struct StubAgentConnection {
216 sessions: Arc<Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
217 permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
218 next_prompt_updates: Arc<Mutex<Vec<acp::SessionUpdate>>>,
219 }
220
221 impl StubAgentConnection {
222 pub fn new() -> Self {
223 Self {
224 next_prompt_updates: Default::default(),
225 permission_requests: HashMap::default(),
226 sessions: Arc::default(),
227 }
228 }
229
230 pub fn set_next_prompt_updates(&self, updates: Vec<acp::SessionUpdate>) {
231 *self.next_prompt_updates.lock() = updates;
232 }
233
234 pub fn with_permission_requests(
235 mut self,
236 permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
237 ) -> Self {
238 self.permission_requests = permission_requests;
239 self
240 }
241
242 pub fn send_update(
243 &self,
244 session_id: acp::SessionId,
245 update: acp::SessionUpdate,
246 cx: &mut App,
247 ) {
248 self.sessions
249 .lock()
250 .get(&session_id)
251 .unwrap()
252 .update(cx, |thread, cx| {
253 thread.handle_session_update(update.clone(), cx).unwrap();
254 })
255 .unwrap();
256 }
257 }
258
259 impl AgentConnection for StubAgentConnection {
260 fn auth_methods(&self) -> &[acp::AuthMethod] {
261 &[]
262 }
263
264 fn new_thread(
265 self: Rc<Self>,
266 project: Entity<Project>,
267 _cwd: &Path,
268 cx: &mut gpui::App,
269 ) -> Task<gpui::Result<Entity<AcpThread>>> {
270 let session_id = acp::SessionId(self.sessions.lock().len().to_string().into());
271 let thread =
272 cx.new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx));
273 self.sessions.lock().insert(session_id, thread.downgrade());
274 Task::ready(Ok(thread))
275 }
276
277 fn authenticate(
278 &self,
279 _method_id: acp::AuthMethodId,
280 _cx: &mut App,
281 ) -> Task<gpui::Result<()>> {
282 unimplemented!()
283 }
284
285 fn prompt(
286 &self,
287 _id: Option<UserMessageId>,
288 params: acp::PromptRequest,
289 cx: &mut App,
290 ) -> Task<gpui::Result<acp::PromptResponse>> {
291 let sessions = self.sessions.lock();
292 let thread = sessions.get(¶ms.session_id).unwrap();
293 let mut tasks = vec![];
294 for update in self.next_prompt_updates.lock().drain(..) {
295 let thread = thread.clone();
296 let update = update.clone();
297 let permission_request = if let acp::SessionUpdate::ToolCall(tool_call) = &update
298 && let Some(options) = self.permission_requests.get(&tool_call.id)
299 {
300 Some((tool_call.clone(), options.clone()))
301 } else {
302 None
303 };
304 let task = cx.spawn(async move |cx| {
305 if let Some((tool_call, options)) = permission_request {
306 let permission = thread.update(cx, |thread, cx| {
307 thread.request_tool_call_authorization(
308 tool_call.clone().into(),
309 options.clone(),
310 cx,
311 )
312 })?;
313 permission?.await?;
314 }
315 thread.update(cx, |thread, cx| {
316 thread.handle_session_update(update.clone(), cx).unwrap();
317 })?;
318 anyhow::Ok(())
319 });
320 tasks.push(task);
321 }
322 cx.spawn(async move |_| {
323 try_join_all(tasks).await?;
324 Ok(acp::PromptResponse {
325 stop_reason: acp::StopReason::EndTurn,
326 })
327 })
328 }
329
330 fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) {
331 unimplemented!()
332 }
333
334 fn session_editor(
335 &self,
336 _session_id: &agent_client_protocol::SessionId,
337 _cx: &mut App,
338 ) -> Option<Rc<dyn AgentSessionEditor>> {
339 Some(Rc::new(StubAgentSessionEditor))
340 }
341
342 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
343 self
344 }
345 }
346
347 struct StubAgentSessionEditor;
348
349 impl AgentSessionEditor for StubAgentSessionEditor {
350 fn truncate(&self, _: UserMessageId, _: &mut App) -> Task<Result<()>> {
351 Task::ready(Ok(()))
352 }
353 }
354}
355
356#[cfg(feature = "test-support")]
357pub use test_support::*;