1use crate::AcpThread;
2use agent_client_protocol::{self as acp};
3use anyhow::Result;
4use collections::IndexMap;
5use gpui::{Entity, SharedString, Task};
6use language_model::LanguageModelProviderId;
7use project::Project;
8use serde::{Deserialize, Serialize};
9use std::{any::Any, error::Error, fmt, path::Path, rc::Rc, sync::Arc};
10use ui::{App, IconName};
11use uuid::Uuid;
12
13#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
14pub struct UserMessageId(Arc<str>);
15
16impl UserMessageId {
17 pub fn new() -> Self {
18 Self(Uuid::new_v4().to_string().into())
19 }
20}
21
/// Interface to an agent backend: creating threads, authenticating, sending
/// prompts, and exposing optional per-session capabilities.
///
/// Optional capabilities are surfaced through methods that return
/// `Option<Rc<dyn …>>`; their default implementations return `None`.
pub trait AgentConnection {
    /// Identifier reported for this connection in telemetry.
    fn telemetry_id(&self) -> SharedString;

    /// Creates a new thread for `project`, resolving to the [`AcpThread`] entity.
    ///
    /// NOTE(review): `cwd` is presumably the working directory for the new
    /// session — confirm against implementors.
    fn new_thread(
        self: Rc<Self>,
        project: Entity<Project>,
        cwd: &Path,
        cx: &mut App,
    ) -> Task<Result<Entity<AcpThread>>>;

    /// Returns the authentication methods this connection offers.
    fn auth_methods(&self) -> &[acp::AuthMethod];

    /// Authenticates using the method identified by `method`.
    fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;

    /// Sends a prompt request, resolving to the agent's [`acp::PromptResponse`].
    ///
    /// `user_message_id` optionally identifies the user message that this
    /// prompt corresponds to.
    fn prompt(
        &self,
        user_message_id: Option<UserMessageId>,
        params: acp::PromptRequest,
        cx: &mut App,
    ) -> Task<Result<acp::PromptResponse>>;

    /// Returns a handle for resuming the given session, if available.
    ///
    /// The default implementation returns `None`.
    fn resume(
        &self,
        _session_id: &acp::SessionId,
        _cx: &App,
    ) -> Option<Rc<dyn AgentSessionResume>> {
        None
    }

    /// Requests cancellation for the session identified by `session_id`.
    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);

    /// Returns a handle for truncating the given session, if available.
    ///
    /// The default implementation returns `None`.
    fn truncate(
        &self,
        _session_id: &acp::SessionId,
        _cx: &App,
    ) -> Option<Rc<dyn AgentSessionTruncate>> {
        None
    }

    /// Returns a handle for setting the title of the given session, if available.
    ///
    /// The default implementation returns `None`.
    fn set_title(
        &self,
        _session_id: &acp::SessionId,
        _cx: &App,
    ) -> Option<Rc<dyn AgentSessionSetTitle>> {
        None
    }

    /// Returns an `Rc<dyn AgentModelSelector>` for the given session if the
    /// model selection capability is supported.
    ///
    /// If the agent does not support model selection, returns [None].
    /// This allows sharing the selector in UI components.
    fn model_selector(&self, _session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
        None
    }

    /// Returns the telemetry capability of this connection, if available.
    ///
    /// The default implementation returns `None`.
    fn telemetry(&self) -> Option<Rc<dyn AgentTelemetry>> {
        None
    }

    /// Returns the session-mode capability for the given session, if available.
    ///
    /// The default implementation returns `None`.
    fn session_modes(
        &self,
        _session_id: &acp::SessionId,
        _cx: &App,
    ) -> Option<Rc<dyn AgentSessionModes>> {
        None
    }

    /// Converts this connection into an `Rc<dyn Any>` so that it can be
    /// downcast back to its concrete type.
    fn into_any(self: Rc<Self>) -> Rc<dyn Any>;
}
91
92impl dyn AgentConnection {
93 pub fn downcast<T: 'static + AgentConnection + Sized>(self: Rc<Self>) -> Option<Rc<T>> {
94 self.into_any().downcast().ok()
95 }
96}
97
/// Optional session capability returned by [`AgentConnection::truncate`].
pub trait AgentSessionTruncate {
    /// Runs the truncation for the given user message.
    ///
    /// NOTE(review): presumably the session is truncated at `message_id` —
    /// confirm the exact cut-off semantics against implementors.
    fn run(&self, message_id: UserMessageId, cx: &mut App) -> Task<Result<()>>;
}
101
/// Optional session capability returned by [`AgentConnection::resume`].
pub trait AgentSessionResume {
    /// Resumes the session, resolving to an [`acp::PromptResponse`].
    fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>>;
}
105
/// Optional session capability returned by [`AgentConnection::set_title`].
pub trait AgentSessionSetTitle {
    /// Sets the session title to `title`.
    fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>>;
}
109
/// Optional capability returned by [`AgentConnection::telemetry`].
pub trait AgentTelemetry {
    /// A representation of the current thread state that can be serialized for
    /// storage with telemetry events.
    ///
    /// Resolves to a JSON value for the session identified by `session_id`.
    fn thread_data(
        &self,
        session_id: &acp::SessionId,
        cx: &mut App,
    ) -> Task<Result<serde_json::Value>>;
}
119
/// Optional session capability returned by [`AgentConnection::session_modes`].
pub trait AgentSessionModes {
    /// Returns the id of the current session mode.
    fn current_mode(&self) -> acp::SessionModeId;

    /// Returns the session modes.
    ///
    /// NOTE(review): presumably every mode available for the session — confirm.
    fn all_modes(&self) -> Vec<acp::SessionMode>;

    /// Requests a switch to `mode`; the task reports success or failure.
    fn set_mode(&self, mode: acp::SessionModeId, cx: &mut App) -> Task<Result<()>>;
}
127
128#[derive(Debug)]
129pub struct AuthRequired {
130 pub description: Option<String>,
131 pub provider_id: Option<LanguageModelProviderId>,
132}
133
134impl AuthRequired {
135 pub fn new() -> Self {
136 Self {
137 description: None,
138 provider_id: None,
139 }
140 }
141
142 pub fn with_description(mut self, description: String) -> Self {
143 self.description = Some(description);
144 self
145 }
146
147 pub fn with_language_model_provider(mut self, provider_id: LanguageModelProviderId) -> Self {
148 self.provider_id = Some(provider_id);
149 self
150 }
151}
152
153impl Error for AuthRequired {}
154impl fmt::Display for AuthRequired {
155 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
156 write!(f, "Authentication required")
157 }
158}
159
/// Trait for agents that support listing, selecting, and querying language models.
///
/// This is an optional capability; agents indicate support via
/// [`AgentConnection::model_selector`], which hands out a selector for a
/// specific session.
pub trait AgentModelSelector: 'static {
    /// Lists all available language models for this agent.
    ///
    /// # Parameters
    /// - `cx`: The GPUI app context for async operations and global access.
    ///
    /// # Returns
    /// A task resolving to the list of models or an error (e.g., if no models are configured).
    fn list_models(&self, cx: &mut App) -> Task<Result<AgentModelList>>;

    /// Selects a model for the session (thread) this selector belongs to.
    ///
    /// This sets the default model for future interactions in the session.
    /// If the session doesn't exist or the model is invalid, it returns an error.
    ///
    /// # Parameters
    /// - `model_id`: The id of the model to select (should be one returned by
    ///   [`Self::list_models`]).
    /// - `cx`: The GPUI app context.
    ///
    /// # Returns
    /// A task resolving to `Ok(())` on success or an error.
    fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>>;

    /// Retrieves the currently selected model for the session (thread) this
    /// selector belongs to.
    ///
    /// # Parameters
    /// - `cx`: The GPUI app context.
    ///
    /// # Returns
    /// A task resolving to the selected model (always set) or an error (e.g., session not found).
    fn selected_model(&self, cx: &mut App) -> Task<Result<AgentModelInfo>>;

    /// Whenever the model list is updated the receiver will be notified.
    /// Optional for agents that don't update their model list; the default
    /// implementation returns `None`.
    fn watch(&self, _cx: &mut App) -> Option<watch::Receiver<()>> {
        None
    }

    /// Returns whether the model picker should render a footer.
    ///
    /// Defaults to `false`.
    fn should_render_footer(&self) -> bool {
        false
    }

    /// Whether this selector supports the favorites feature.
    /// Only the native agent uses the model ID format that maps to settings.
    ///
    /// Defaults to `false`.
    fn supports_favorites(&self) -> bool {
        false
    }
}
212
/// Icon for a model in the model selector.
///
/// Used as the optional `icon` field of [`AgentModelInfo`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AgentModelIcon {
    /// A built-in icon from Zed's icon set.
    Named(IconName),
    /// Path to a custom SVG icon file.
    Path(SharedString),
}
221
/// Description of a single model, as listed in an [`AgentModelList`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AgentModelInfo {
    /// Identifier of the model.
    pub id: acp::ModelId,
    /// Name of the model.
    pub name: SharedString,
    /// Optional description of the model.
    pub description: Option<SharedString>,
    /// Optional icon; `None` when converted from an [`acp::ModelInfo`].
    pub icon: Option<AgentModelIcon>,
}
229
230impl From<acp::ModelInfo> for AgentModelInfo {
231 fn from(info: acp::ModelInfo) -> Self {
232 Self {
233 id: info.model_id,
234 name: info.name.into(),
235 description: info.description.map(|desc| desc.into()),
236 icon: None,
237 }
238 }
239}
240
/// Name of a model group; used as the key of [`AgentModelList::Grouped`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct AgentModelGroupName(pub SharedString);
243
244#[derive(Debug, Clone)]
245pub enum AgentModelList {
246 Flat(Vec<AgentModelInfo>),
247 Grouped(IndexMap<AgentModelGroupName, Vec<AgentModelInfo>>),
248}
249
250impl AgentModelList {
251 pub fn is_empty(&self) -> bool {
252 match self {
253 AgentModelList::Flat(models) => models.is_empty(),
254 AgentModelList::Grouped(groups) => groups.is_empty(),
255 }
256 }
257
258 pub fn is_flat(&self) -> bool {
259 matches!(self, AgentModelList::Flat(_))
260 }
261}
262
263#[cfg(feature = "test-support")]
264mod test_support {
265 use std::sync::Arc;
266
267 use action_log::ActionLog;
268 use collections::HashMap;
269 use futures::{channel::oneshot, future::try_join_all};
270 use gpui::{AppContext as _, WeakEntity};
271 use parking_lot::Mutex;
272
273 use super::*;
274
275 #[derive(Clone, Default)]
276 pub struct StubAgentConnection {
277 sessions: Arc<Mutex<HashMap<acp::SessionId, Session>>>,
278 permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
279 next_prompt_updates: Arc<Mutex<Vec<acp::SessionUpdate>>>,
280 }
281
282 struct Session {
283 thread: WeakEntity<AcpThread>,
284 response_tx: Option<oneshot::Sender<acp::StopReason>>,
285 }
286
287 impl StubAgentConnection {
288 pub fn new() -> Self {
289 Self {
290 next_prompt_updates: Default::default(),
291 permission_requests: HashMap::default(),
292 sessions: Arc::default(),
293 }
294 }
295
296 pub fn set_next_prompt_updates(&self, updates: Vec<acp::SessionUpdate>) {
297 *self.next_prompt_updates.lock() = updates;
298 }
299
300 pub fn with_permission_requests(
301 mut self,
302 permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
303 ) -> Self {
304 self.permission_requests = permission_requests;
305 self
306 }
307
308 pub fn send_update(
309 &self,
310 session_id: acp::SessionId,
311 update: acp::SessionUpdate,
312 cx: &mut App,
313 ) {
314 assert!(
315 self.next_prompt_updates.lock().is_empty(),
316 "Use either send_update or set_next_prompt_updates"
317 );
318
319 self.sessions
320 .lock()
321 .get(&session_id)
322 .unwrap()
323 .thread
324 .update(cx, |thread, cx| {
325 thread.handle_session_update(update, cx).unwrap();
326 })
327 .unwrap();
328 }
329
330 pub fn end_turn(&self, session_id: acp::SessionId, stop_reason: acp::StopReason) {
331 self.sessions
332 .lock()
333 .get_mut(&session_id)
334 .unwrap()
335 .response_tx
336 .take()
337 .expect("No pending turn")
338 .send(stop_reason)
339 .unwrap();
340 }
341 }
342
343 impl AgentConnection for StubAgentConnection {
344 fn telemetry_id(&self) -> SharedString {
345 "stub".into()
346 }
347
348 fn auth_methods(&self) -> &[acp::AuthMethod] {
349 &[]
350 }
351
352 fn new_thread(
353 self: Rc<Self>,
354 project: Entity<Project>,
355 _cwd: &Path,
356 cx: &mut gpui::App,
357 ) -> Task<gpui::Result<Entity<AcpThread>>> {
358 let session_id = acp::SessionId::new(self.sessions.lock().len().to_string());
359 let action_log = cx.new(|_| ActionLog::new(project.clone()));
360 let thread = cx.new(|cx| {
361 AcpThread::new(
362 "Test",
363 self.clone(),
364 project,
365 action_log,
366 session_id.clone(),
367 watch::Receiver::constant(
368 acp::PromptCapabilities::new()
369 .image(true)
370 .audio(true)
371 .embedded_context(true),
372 ),
373 cx,
374 )
375 });
376 self.sessions.lock().insert(
377 session_id,
378 Session {
379 thread: thread.downgrade(),
380 response_tx: None,
381 },
382 );
383 Task::ready(Ok(thread))
384 }
385
386 fn authenticate(
387 &self,
388 _method_id: acp::AuthMethodId,
389 _cx: &mut App,
390 ) -> Task<gpui::Result<()>> {
391 unimplemented!()
392 }
393
394 fn prompt(
395 &self,
396 _id: Option<UserMessageId>,
397 params: acp::PromptRequest,
398 cx: &mut App,
399 ) -> Task<gpui::Result<acp::PromptResponse>> {
400 let mut sessions = self.sessions.lock();
401 let Session {
402 thread,
403 response_tx,
404 } = sessions.get_mut(¶ms.session_id).unwrap();
405 let mut tasks = vec![];
406 if self.next_prompt_updates.lock().is_empty() {
407 let (tx, rx) = oneshot::channel();
408 response_tx.replace(tx);
409 cx.spawn(async move |_| {
410 let stop_reason = rx.await?;
411 Ok(acp::PromptResponse::new(stop_reason))
412 })
413 } else {
414 for update in self.next_prompt_updates.lock().drain(..) {
415 let thread = thread.clone();
416 let update = update.clone();
417 let permission_request = if let acp::SessionUpdate::ToolCall(tool_call) =
418 &update
419 && let Some(options) = self.permission_requests.get(&tool_call.tool_call_id)
420 {
421 Some((tool_call.clone(), options.clone()))
422 } else {
423 None
424 };
425 let task = cx.spawn(async move |cx| {
426 if let Some((tool_call, options)) = permission_request {
427 thread
428 .update(cx, |thread, cx| {
429 thread.request_tool_call_authorization(
430 tool_call.clone().into(),
431 options.clone(),
432 false,
433 cx,
434 )
435 })??
436 .await;
437 }
438 thread.update(cx, |thread, cx| {
439 thread.handle_session_update(update.clone(), cx).unwrap();
440 })?;
441 anyhow::Ok(())
442 });
443 tasks.push(task);
444 }
445
446 cx.spawn(async move |_| {
447 try_join_all(tasks).await?;
448 Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
449 })
450 }
451 }
452
453 fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
454 if let Some(end_turn_tx) = self
455 .sessions
456 .lock()
457 .get_mut(session_id)
458 .unwrap()
459 .response_tx
460 .take()
461 {
462 end_turn_tx.send(acp::StopReason::Cancelled).unwrap();
463 }
464 }
465
466 fn truncate(
467 &self,
468 _session_id: &agent_client_protocol::SessionId,
469 _cx: &App,
470 ) -> Option<Rc<dyn AgentSessionTruncate>> {
471 Some(Rc::new(StubAgentSessionEditor))
472 }
473
474 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
475 self
476 }
477 }
478
479 struct StubAgentSessionEditor;
480
481 impl AgentSessionTruncate for StubAgentSessionEditor {
482 fn run(&self, _: UserMessageId, _: &mut App) -> Task<Result<()>> {
483 Task::ready(Ok(()))
484 }
485 }
486}
487
488#[cfg(feature = "test-support")]
489pub use test_support::*;