1use crate::AcpThread;
2use agent_client_protocol::{self as acp};
3use anyhow::Result;
4use collections::IndexMap;
5use gpui::{Entity, SharedString, Task};
6use language_model::LanguageModelProviderId;
7use project::Project;
8use serde::{Deserialize, Serialize};
9use std::{any::Any, error::Error, fmt, path::Path, rc::Rc, sync::Arc};
10use ui::{App, IconName};
11use uuid::Uuid;
12
13#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
14pub struct UserMessageId(Arc<str>);
15
16impl UserMessageId {
17 pub fn new() -> Self {
18 Self(Uuid::new_v4().to_string().into())
19 }
20}
21
22pub trait AgentConnection {
23 fn new_thread(
24 self: Rc<Self>,
25 project: Entity<Project>,
26 cwd: &Path,
27 cx: &mut App,
28 ) -> Task<Result<Entity<AcpThread>>>;
29
30 fn auth_methods(&self) -> &[acp::AuthMethod];
31
32 fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;
33
34 fn prompt(
35 &self,
36 user_message_id: Option<UserMessageId>,
37 params: acp::PromptRequest,
38 cx: &mut App,
39 ) -> Task<Result<acp::PromptResponse>>;
40
41 fn resume(
42 &self,
43 _session_id: &acp::SessionId,
44 _cx: &App,
45 ) -> Option<Rc<dyn AgentSessionResume>> {
46 None
47 }
48
49 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
50
51 fn truncate(
52 &self,
53 _session_id: &acp::SessionId,
54 _cx: &App,
55 ) -> Option<Rc<dyn AgentSessionTruncate>> {
56 None
57 }
58
59 fn set_title(
60 &self,
61 _session_id: &acp::SessionId,
62 _cx: &App,
63 ) -> Option<Rc<dyn AgentSessionSetTitle>> {
64 None
65 }
66
67 /// Returns this agent as an [Rc<dyn ModelSelector>] if the model selection capability is supported.
68 ///
69 /// If the agent does not support model selection, returns [None].
70 /// This allows sharing the selector in UI components.
71 fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
72 None
73 }
74
75 fn telemetry(&self) -> Option<Rc<dyn AgentTelemetry>> {
76 None
77 }
78
79 fn session_modes(
80 &self,
81 _session_id: &acp::SessionId,
82 _cx: &App,
83 ) -> Option<Rc<dyn AgentSessionModes>> {
84 None
85 }
86
87 fn into_any(self: Rc<Self>) -> Rc<dyn Any>;
88}
89
90impl dyn AgentConnection {
91 pub fn downcast<T: 'static + AgentConnection + Sized>(self: Rc<Self>) -> Option<Rc<T>> {
92 self.into_any().downcast().ok()
93 }
94}
95
96pub trait AgentSessionTruncate {
97 fn run(&self, message_id: UserMessageId, cx: &mut App) -> Task<Result<()>>;
98}
99
100pub trait AgentSessionResume {
101 fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>>;
102}
103
104pub trait AgentSessionSetTitle {
105 fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>>;
106}
107
108pub trait AgentTelemetry {
109 /// The name of the agent used for telemetry.
110 fn agent_name(&self) -> String;
111
112 /// A representation of the current thread state that can be serialized for
113 /// storage with telemetry events.
114 fn thread_data(
115 &self,
116 session_id: &acp::SessionId,
117 cx: &mut App,
118 ) -> Task<Result<serde_json::Value>>;
119}
120
121pub trait AgentSessionModes {
122 fn current_mode(&self) -> acp::SessionModeId;
123
124 fn all_modes(&self) -> Vec<acp::SessionMode>;
125
126 fn set_mode(&self, mode: acp::SessionModeId, cx: &mut App) -> Task<Result<()>>;
127}
128
129#[derive(Debug)]
130pub struct AuthRequired {
131 pub description: Option<String>,
132 pub provider_id: Option<LanguageModelProviderId>,
133}
134
135impl AuthRequired {
136 pub fn new() -> Self {
137 Self {
138 description: None,
139 provider_id: None,
140 }
141 }
142
143 pub fn with_description(mut self, description: String) -> Self {
144 self.description = Some(description);
145 self
146 }
147
148 pub fn with_language_model_provider(mut self, provider_id: LanguageModelProviderId) -> Self {
149 self.provider_id = Some(provider_id);
150 self
151 }
152}
153
154impl Error for AuthRequired {}
155impl fmt::Display for AuthRequired {
156 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
157 write!(f, "Authentication required")
158 }
159}
160
161/// Trait for agents that support listing, selecting, and querying language models.
162///
163/// This is an optional capability; agents indicate support via [AgentConnection::model_selector].
164pub trait AgentModelSelector: 'static {
165 /// Lists all available language models for this agent.
166 ///
167 /// # Parameters
168 /// - `cx`: The GPUI app context for async operations and global access.
169 ///
170 /// # Returns
171 /// A task resolving to the list of models or an error (e.g., if no models are configured).
172 fn list_models(&self, cx: &mut App) -> Task<Result<AgentModelList>>;
173
174 /// Selects a model for a specific session (thread).
175 ///
176 /// This sets the default model for future interactions in the session.
177 /// If the session doesn't exist or the model is invalid, it returns an error.
178 ///
179 /// # Parameters
180 /// - `session_id`: The ID of the session (thread) to apply the model to.
181 /// - `model`: The model to select (should be one from [list_models]).
182 /// - `cx`: The GPUI app context.
183 ///
184 /// # Returns
185 /// A task resolving to `Ok(())` on success or an error.
186 fn select_model(
187 &self,
188 session_id: acp::SessionId,
189 model_id: AgentModelId,
190 cx: &mut App,
191 ) -> Task<Result<()>>;
192
193 /// Retrieves the currently selected model for a specific session (thread).
194 ///
195 /// # Parameters
196 /// - `session_id`: The ID of the session (thread) to query.
197 /// - `cx`: The GPUI app context.
198 ///
199 /// # Returns
200 /// A task resolving to the selected model (always set) or an error (e.g., session not found).
201 fn selected_model(
202 &self,
203 session_id: &acp::SessionId,
204 cx: &mut App,
205 ) -> Task<Result<AgentModelInfo>>;
206
207 /// Whenever the model list is updated the receiver will be notified.
208 fn watch(&self, cx: &mut App) -> watch::Receiver<()>;
209}
210
211#[derive(Debug, Clone, PartialEq, Eq, Hash)]
212pub struct AgentModelId(pub SharedString);
213
214impl std::ops::Deref for AgentModelId {
215 type Target = SharedString;
216
217 fn deref(&self) -> &Self::Target {
218 &self.0
219 }
220}
221
222impl fmt::Display for AgentModelId {
223 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
224 self.0.fmt(f)
225 }
226}
227
228#[derive(Debug, Clone, PartialEq, Eq)]
229pub struct AgentModelInfo {
230 pub id: AgentModelId,
231 pub name: SharedString,
232 pub icon: Option<IconName>,
233}
234
235#[derive(Debug, Clone, PartialEq, Eq, Hash)]
236pub struct AgentModelGroupName(pub SharedString);
237
238#[derive(Debug, Clone)]
239pub enum AgentModelList {
240 Flat(Vec<AgentModelInfo>),
241 Grouped(IndexMap<AgentModelGroupName, Vec<AgentModelInfo>>),
242}
243
244impl AgentModelList {
245 pub fn is_empty(&self) -> bool {
246 match self {
247 AgentModelList::Flat(models) => models.is_empty(),
248 AgentModelList::Grouped(groups) => groups.is_empty(),
249 }
250 }
251}
252
253#[cfg(feature = "test-support")]
254mod test_support {
255 use std::sync::Arc;
256
257 use action_log::ActionLog;
258 use collections::HashMap;
259 use futures::{channel::oneshot, future::try_join_all};
260 use gpui::{AppContext as _, WeakEntity};
261 use parking_lot::Mutex;
262
263 use super::*;
264
265 #[derive(Clone, Default)]
266 pub struct StubAgentConnection {
267 sessions: Arc<Mutex<HashMap<acp::SessionId, Session>>>,
268 permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
269 next_prompt_updates: Arc<Mutex<Vec<acp::SessionUpdate>>>,
270 }
271
272 struct Session {
273 thread: WeakEntity<AcpThread>,
274 response_tx: Option<oneshot::Sender<acp::StopReason>>,
275 }
276
277 impl StubAgentConnection {
278 pub fn new() -> Self {
279 Self {
280 next_prompt_updates: Default::default(),
281 permission_requests: HashMap::default(),
282 sessions: Arc::default(),
283 }
284 }
285
286 pub fn set_next_prompt_updates(&self, updates: Vec<acp::SessionUpdate>) {
287 *self.next_prompt_updates.lock() = updates;
288 }
289
290 pub fn with_permission_requests(
291 mut self,
292 permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
293 ) -> Self {
294 self.permission_requests = permission_requests;
295 self
296 }
297
298 pub fn send_update(
299 &self,
300 session_id: acp::SessionId,
301 update: acp::SessionUpdate,
302 cx: &mut App,
303 ) {
304 assert!(
305 self.next_prompt_updates.lock().is_empty(),
306 "Use either send_update or set_next_prompt_updates"
307 );
308
309 self.sessions
310 .lock()
311 .get(&session_id)
312 .unwrap()
313 .thread
314 .update(cx, |thread, cx| {
315 thread.handle_session_update(update, cx).unwrap();
316 })
317 .unwrap();
318 }
319
320 pub fn end_turn(&self, session_id: acp::SessionId, stop_reason: acp::StopReason) {
321 self.sessions
322 .lock()
323 .get_mut(&session_id)
324 .unwrap()
325 .response_tx
326 .take()
327 .expect("No pending turn")
328 .send(stop_reason)
329 .unwrap();
330 }
331 }
332
333 impl AgentConnection for StubAgentConnection {
334 fn auth_methods(&self) -> &[acp::AuthMethod] {
335 &[]
336 }
337
338 fn new_thread(
339 self: Rc<Self>,
340 project: Entity<Project>,
341 _cwd: &Path,
342 cx: &mut gpui::App,
343 ) -> Task<gpui::Result<Entity<AcpThread>>> {
344 let session_id = acp::SessionId(self.sessions.lock().len().to_string().into());
345 let action_log = cx.new(|_| ActionLog::new(project.clone()));
346 let thread = cx.new(|cx| {
347 AcpThread::new(
348 "Test",
349 self.clone(),
350 project,
351 action_log,
352 session_id.clone(),
353 watch::Receiver::constant(acp::PromptCapabilities {
354 image: true,
355 audio: true,
356 embedded_context: true,
357 meta: None,
358 }),
359 cx,
360 )
361 });
362 self.sessions.lock().insert(
363 session_id,
364 Session {
365 thread: thread.downgrade(),
366 response_tx: None,
367 },
368 );
369 Task::ready(Ok(thread))
370 }
371
372 fn authenticate(
373 &self,
374 _method_id: acp::AuthMethodId,
375 _cx: &mut App,
376 ) -> Task<gpui::Result<()>> {
377 unimplemented!()
378 }
379
380 fn prompt(
381 &self,
382 _id: Option<UserMessageId>,
383 params: acp::PromptRequest,
384 cx: &mut App,
385 ) -> Task<gpui::Result<acp::PromptResponse>> {
386 let mut sessions = self.sessions.lock();
387 let Session {
388 thread,
389 response_tx,
390 } = sessions.get_mut(¶ms.session_id).unwrap();
391 let mut tasks = vec![];
392 if self.next_prompt_updates.lock().is_empty() {
393 let (tx, rx) = oneshot::channel();
394 response_tx.replace(tx);
395 cx.spawn(async move |_| {
396 let stop_reason = rx.await?;
397 Ok(acp::PromptResponse {
398 stop_reason,
399 meta: None,
400 })
401 })
402 } else {
403 for update in self.next_prompt_updates.lock().drain(..) {
404 let thread = thread.clone();
405 let update = update.clone();
406 let permission_request = if let acp::SessionUpdate::ToolCall(tool_call) =
407 &update
408 && let Some(options) = self.permission_requests.get(&tool_call.id)
409 {
410 Some((tool_call.clone(), options.clone()))
411 } else {
412 None
413 };
414 let task = cx.spawn(async move |cx| {
415 if let Some((tool_call, options)) = permission_request {
416 thread
417 .update(cx, |thread, cx| {
418 thread.request_tool_call_authorization(
419 tool_call.clone().into(),
420 options.clone(),
421 false,
422 cx,
423 )
424 })??
425 .await;
426 }
427 thread.update(cx, |thread, cx| {
428 thread.handle_session_update(update.clone(), cx).unwrap();
429 })?;
430 anyhow::Ok(())
431 });
432 tasks.push(task);
433 }
434
435 cx.spawn(async move |_| {
436 try_join_all(tasks).await?;
437 Ok(acp::PromptResponse {
438 stop_reason: acp::StopReason::EndTurn,
439 meta: None,
440 })
441 })
442 }
443 }
444
445 fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
446 if let Some(end_turn_tx) = self
447 .sessions
448 .lock()
449 .get_mut(session_id)
450 .unwrap()
451 .response_tx
452 .take()
453 {
454 end_turn_tx.send(acp::StopReason::Cancelled).unwrap();
455 }
456 }
457
458 fn truncate(
459 &self,
460 _session_id: &agent_client_protocol::SessionId,
461 _cx: &App,
462 ) -> Option<Rc<dyn AgentSessionTruncate>> {
463 Some(Rc::new(StubAgentSessionEditor))
464 }
465
466 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
467 self
468 }
469 }
470
471 struct StubAgentSessionEditor;
472
473 impl AgentSessionTruncate for StubAgentSessionEditor {
474 fn run(&self, _: UserMessageId, _: &mut App) -> Task<Result<()>> {
475 Task::ready(Ok(()))
476 }
477 }
478}
479
480#[cfg(feature = "test-support")]
481pub use test_support::*;