1use crate::AcpThread;
2use agent_client_protocol::{self as acp};
3use anyhow::Result;
4use collections::IndexMap;
5use gpui::{Entity, SharedString, Task};
6use language_model::LanguageModelProviderId;
7use project::Project;
8use serde::{Deserialize, Serialize};
9use std::{any::Any, error::Error, fmt, path::Path, rc::Rc, sync::Arc};
10use ui::{App, IconName};
11use uuid::Uuid;
12
13#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
14pub struct UserMessageId(Arc<str>);
15
16impl UserMessageId {
17 pub fn new() -> Self {
18 Self(Uuid::new_v4().to_string().into())
19 }
20}
21
22pub trait AgentConnection {
23 fn new_thread(
24 self: Rc<Self>,
25 project: Entity<Project>,
26 cwd: &Path,
27 cx: &mut App,
28 ) -> Task<Result<Entity<AcpThread>>>;
29
30 fn auth_methods(&self) -> &[acp::AuthMethod];
31
32 fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;
33
34 fn prompt(
35 &self,
36 user_message_id: Option<UserMessageId>,
37 params: acp::PromptRequest,
38 cx: &mut App,
39 ) -> Task<Result<acp::PromptResponse>>;
40
41 fn prompt_capabilities(&self) -> acp::PromptCapabilities;
42
43 fn resume(
44 &self,
45 _session_id: &acp::SessionId,
46 _cx: &mut App,
47 ) -> Option<Rc<dyn AgentSessionResume>> {
48 None
49 }
50
51 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
52
53 fn session_editor(
54 &self,
55 _session_id: &acp::SessionId,
56 _cx: &mut App,
57 ) -> Option<Rc<dyn AgentSessionEditor>> {
58 None
59 }
60
61 /// Returns this agent as an [Rc<dyn ModelSelector>] if the model selection capability is supported.
62 ///
63 /// If the agent does not support model selection, returns [None].
64 /// This allows sharing the selector in UI components.
65 fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
66 None
67 }
68
69 fn telemetry(&self) -> Option<Rc<dyn AgentTelemetry>> {
70 None
71 }
72
73 fn into_any(self: Rc<Self>) -> Rc<dyn Any>;
74}
75
76impl dyn AgentConnection {
77 pub fn downcast<T: 'static + AgentConnection + Sized>(self: Rc<Self>) -> Option<Rc<T>> {
78 self.into_any().downcast().ok()
79 }
80}
81
82pub trait AgentSessionEditor {
83 fn truncate(&self, message_id: UserMessageId, cx: &mut App) -> Task<Result<()>>;
84}
85
86pub trait AgentSessionResume {
87 fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>>;
88}
89
90pub trait AgentTelemetry {
91 /// The name of the agent used for telemetry.
92 fn agent_name(&self) -> String;
93
94 /// A representation of the current thread state that can be serialized for
95 /// storage with telemetry events.
96 fn thread_data(
97 &self,
98 session_id: &acp::SessionId,
99 cx: &mut App,
100 ) -> Task<Result<serde_json::Value>>;
101}
102
103#[derive(Debug)]
104pub struct AuthRequired {
105 pub description: Option<String>,
106 pub provider_id: Option<LanguageModelProviderId>,
107}
108
109impl AuthRequired {
110 pub fn new() -> Self {
111 Self {
112 description: None,
113 provider_id: None,
114 }
115 }
116
117 pub fn with_description(mut self, description: String) -> Self {
118 self.description = Some(description);
119 self
120 }
121
122 pub fn with_language_model_provider(mut self, provider_id: LanguageModelProviderId) -> Self {
123 self.provider_id = Some(provider_id);
124 self
125 }
126}
127
128impl Error for AuthRequired {}
129impl fmt::Display for AuthRequired {
130 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
131 write!(f, "Authentication required")
132 }
133}
134
135/// Trait for agents that support listing, selecting, and querying language models.
136///
137/// This is an optional capability; agents indicate support via [AgentConnection::model_selector].
138pub trait AgentModelSelector: 'static {
139 /// Lists all available language models for this agent.
140 ///
141 /// # Parameters
142 /// - `cx`: The GPUI app context for async operations and global access.
143 ///
144 /// # Returns
145 /// A task resolving to the list of models or an error (e.g., if no models are configured).
146 fn list_models(&self, cx: &mut App) -> Task<Result<AgentModelList>>;
147
148 /// Selects a model for a specific session (thread).
149 ///
150 /// This sets the default model for future interactions in the session.
151 /// If the session doesn't exist or the model is invalid, it returns an error.
152 ///
153 /// # Parameters
154 /// - `session_id`: The ID of the session (thread) to apply the model to.
155 /// - `model`: The model to select (should be one from [list_models]).
156 /// - `cx`: The GPUI app context.
157 ///
158 /// # Returns
159 /// A task resolving to `Ok(())` on success or an error.
160 fn select_model(
161 &self,
162 session_id: acp::SessionId,
163 model_id: AgentModelId,
164 cx: &mut App,
165 ) -> Task<Result<()>>;
166
167 /// Retrieves the currently selected model for a specific session (thread).
168 ///
169 /// # Parameters
170 /// - `session_id`: The ID of the session (thread) to query.
171 /// - `cx`: The GPUI app context.
172 ///
173 /// # Returns
174 /// A task resolving to the selected model (always set) or an error (e.g., session not found).
175 fn selected_model(
176 &self,
177 session_id: &acp::SessionId,
178 cx: &mut App,
179 ) -> Task<Result<AgentModelInfo>>;
180
181 /// Whenever the model list is updated the receiver will be notified.
182 fn watch(&self, cx: &mut App) -> watch::Receiver<()>;
183}
184
185#[derive(Debug, Clone, PartialEq, Eq, Hash)]
186pub struct AgentModelId(pub SharedString);
187
188impl std::ops::Deref for AgentModelId {
189 type Target = SharedString;
190
191 fn deref(&self) -> &Self::Target {
192 &self.0
193 }
194}
195
196impl fmt::Display for AgentModelId {
197 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
198 self.0.fmt(f)
199 }
200}
201
202#[derive(Debug, Clone, PartialEq, Eq)]
203pub struct AgentModelInfo {
204 pub id: AgentModelId,
205 pub name: SharedString,
206 pub icon: Option<IconName>,
207}
208
209#[derive(Debug, Clone, PartialEq, Eq, Hash)]
210pub struct AgentModelGroupName(pub SharedString);
211
212#[derive(Debug, Clone)]
213pub enum AgentModelList {
214 Flat(Vec<AgentModelInfo>),
215 Grouped(IndexMap<AgentModelGroupName, Vec<AgentModelInfo>>),
216}
217
218impl AgentModelList {
219 pub fn is_empty(&self) -> bool {
220 match self {
221 AgentModelList::Flat(models) => models.is_empty(),
222 AgentModelList::Grouped(groups) => groups.is_empty(),
223 }
224 }
225}
226
227#[cfg(feature = "test-support")]
228mod test_support {
229 use std::sync::Arc;
230
231 use action_log::ActionLog;
232 use collections::HashMap;
233 use futures::{channel::oneshot, future::try_join_all};
234 use gpui::{AppContext as _, WeakEntity};
235 use parking_lot::Mutex;
236
237 use super::*;
238
239 #[derive(Clone, Default)]
240 pub struct StubAgentConnection {
241 sessions: Arc<Mutex<HashMap<acp::SessionId, Session>>>,
242 permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
243 next_prompt_updates: Arc<Mutex<Vec<acp::SessionUpdate>>>,
244 }
245
246 struct Session {
247 thread: WeakEntity<AcpThread>,
248 response_tx: Option<oneshot::Sender<acp::StopReason>>,
249 }
250
251 impl StubAgentConnection {
252 pub fn new() -> Self {
253 Self {
254 next_prompt_updates: Default::default(),
255 permission_requests: HashMap::default(),
256 sessions: Arc::default(),
257 }
258 }
259
260 pub fn set_next_prompt_updates(&self, updates: Vec<acp::SessionUpdate>) {
261 *self.next_prompt_updates.lock() = updates;
262 }
263
264 pub fn with_permission_requests(
265 mut self,
266 permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
267 ) -> Self {
268 self.permission_requests = permission_requests;
269 self
270 }
271
272 pub fn send_update(
273 &self,
274 session_id: acp::SessionId,
275 update: acp::SessionUpdate,
276 cx: &mut App,
277 ) {
278 assert!(
279 self.next_prompt_updates.lock().is_empty(),
280 "Use either send_update or set_next_prompt_updates"
281 );
282
283 self.sessions
284 .lock()
285 .get(&session_id)
286 .unwrap()
287 .thread
288 .update(cx, |thread, cx| {
289 thread.handle_session_update(update, cx).unwrap();
290 })
291 .unwrap();
292 }
293
294 pub fn end_turn(&self, session_id: acp::SessionId, stop_reason: acp::StopReason) {
295 self.sessions
296 .lock()
297 .get_mut(&session_id)
298 .unwrap()
299 .response_tx
300 .take()
301 .expect("No pending turn")
302 .send(stop_reason)
303 .unwrap();
304 }
305 }
306
307 impl AgentConnection for StubAgentConnection {
308 fn auth_methods(&self) -> &[acp::AuthMethod] {
309 &[]
310 }
311
312 fn new_thread(
313 self: Rc<Self>,
314 project: Entity<Project>,
315 _cwd: &Path,
316 cx: &mut gpui::App,
317 ) -> Task<gpui::Result<Entity<AcpThread>>> {
318 let session_id = acp::SessionId(self.sessions.lock().len().to_string().into());
319 let action_log = cx.new(|_| ActionLog::new(project.clone()));
320 let thread = cx.new(|_cx| {
321 AcpThread::new(
322 "Test",
323 self.clone(),
324 project,
325 action_log,
326 session_id.clone(),
327 )
328 });
329 self.sessions.lock().insert(
330 session_id,
331 Session {
332 thread: thread.downgrade(),
333 response_tx: None,
334 },
335 );
336 Task::ready(Ok(thread))
337 }
338
339 fn prompt_capabilities(&self) -> acp::PromptCapabilities {
340 acp::PromptCapabilities {
341 image: true,
342 audio: true,
343 embedded_context: true,
344 }
345 }
346
347 fn authenticate(
348 &self,
349 _method_id: acp::AuthMethodId,
350 _cx: &mut App,
351 ) -> Task<gpui::Result<()>> {
352 unimplemented!()
353 }
354
355 fn prompt(
356 &self,
357 _id: Option<UserMessageId>,
358 params: acp::PromptRequest,
359 cx: &mut App,
360 ) -> Task<gpui::Result<acp::PromptResponse>> {
361 let mut sessions = self.sessions.lock();
362 let Session {
363 thread,
364 response_tx,
365 } = sessions.get_mut(¶ms.session_id).unwrap();
366 let mut tasks = vec![];
367 if self.next_prompt_updates.lock().is_empty() {
368 let (tx, rx) = oneshot::channel();
369 response_tx.replace(tx);
370 cx.spawn(async move |_| {
371 let stop_reason = rx.await?;
372 Ok(acp::PromptResponse { stop_reason })
373 })
374 } else {
375 for update in self.next_prompt_updates.lock().drain(..) {
376 let thread = thread.clone();
377 let update = update.clone();
378 let permission_request = if let acp::SessionUpdate::ToolCall(tool_call) =
379 &update
380 && let Some(options) = self.permission_requests.get(&tool_call.id)
381 {
382 Some((tool_call.clone(), options.clone()))
383 } else {
384 None
385 };
386 let task = cx.spawn(async move |cx| {
387 if let Some((tool_call, options)) = permission_request {
388 let permission = thread.update(cx, |thread, cx| {
389 thread.request_tool_call_authorization(
390 tool_call.clone().into(),
391 options.clone(),
392 cx,
393 )
394 })?;
395 permission?.await?;
396 }
397 thread.update(cx, |thread, cx| {
398 thread.handle_session_update(update.clone(), cx).unwrap();
399 })?;
400 anyhow::Ok(())
401 });
402 tasks.push(task);
403 }
404
405 cx.spawn(async move |_| {
406 try_join_all(tasks).await?;
407 Ok(acp::PromptResponse {
408 stop_reason: acp::StopReason::EndTurn,
409 })
410 })
411 }
412 }
413
414 fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
415 if let Some(end_turn_tx) = self
416 .sessions
417 .lock()
418 .get_mut(session_id)
419 .unwrap()
420 .response_tx
421 .take()
422 {
423 end_turn_tx.send(acp::StopReason::Cancelled).unwrap();
424 }
425 }
426
427 fn session_editor(
428 &self,
429 _session_id: &agent_client_protocol::SessionId,
430 _cx: &mut App,
431 ) -> Option<Rc<dyn AgentSessionEditor>> {
432 Some(Rc::new(StubAgentSessionEditor))
433 }
434
435 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
436 self
437 }
438 }
439
440 struct StubAgentSessionEditor;
441
442 impl AgentSessionEditor for StubAgentSessionEditor {
443 fn truncate(&self, _: UserMessageId, _: &mut App) -> Task<Result<()>> {
444 Task::ready(Ok(()))
445 }
446 }
447}
448
449#[cfg(feature = "test-support")]
450pub use test_support::*;