1use crate::AcpThread;
2use agent_client_protocol::{self as acp};
3use anyhow::Result;
4use collections::IndexMap;
5use gpui::{Entity, SharedString, Task};
6use language_model::LanguageModelProviderId;
7use project::Project;
8use serde::{Deserialize, Serialize};
9use std::{any::Any, error::Error, fmt, path::Path, rc::Rc, sync::Arc};
10use ui::{App, IconName};
11use uuid::Uuid;
12
13#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
14pub struct UserMessageId(Arc<str>);
15
16impl UserMessageId {
17 pub fn new() -> Self {
18 Self(Uuid::new_v4().to_string().into())
19 }
20}
21
22pub trait AgentConnection {
23 fn new_thread(
24 self: Rc<Self>,
25 project: Entity<Project>,
26 cwd: &Path,
27 cx: &mut App,
28 ) -> Task<Result<Entity<AcpThread>>>;
29
30 fn auth_methods(&self) -> &[acp::AuthMethod];
31
32 fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;
33
34 fn prompt(
35 &self,
36 user_message_id: Option<UserMessageId>,
37 params: acp::PromptRequest,
38 cx: &mut App,
39 ) -> Task<Result<acp::PromptResponse>>;
40
41 fn resume(
42 &self,
43 _session_id: &acp::SessionId,
44 _cx: &App,
45 ) -> Option<Rc<dyn AgentSessionResume>> {
46 None
47 }
48
49 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
50
51 fn truncate(
52 &self,
53 _session_id: &acp::SessionId,
54 _cx: &App,
55 ) -> Option<Rc<dyn AgentSessionTruncate>> {
56 None
57 }
58
59 fn set_title(
60 &self,
61 _session_id: &acp::SessionId,
62 _cx: &App,
63 ) -> Option<Rc<dyn AgentSessionSetTitle>> {
64 None
65 }
66
67 /// Returns this agent as an [Rc<dyn ModelSelector>] if the model selection capability is supported.
68 ///
69 /// If the agent does not support model selection, returns [None].
70 /// This allows sharing the selector in UI components.
71 fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
72 None
73 }
74
75 fn telemetry(&self) -> Option<Rc<dyn AgentTelemetry>> {
76 None
77 }
78 fn into_any(self: Rc<Self>) -> Rc<dyn Any>;
79}
80
81impl dyn AgentConnection {
82 pub fn downcast<T: 'static + AgentConnection + Sized>(self: Rc<Self>) -> Option<Rc<T>> {
83 self.into_any().downcast().ok()
84 }
85}
86
87pub trait AgentSessionTruncate {
88 fn run(&self, message_id: UserMessageId, cx: &mut App) -> Task<Result<()>>;
89}
90
91pub trait AgentSessionResume {
92 fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>>;
93}
94
95pub trait AgentSessionSetTitle {
96 fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>>;
97}
98
99pub trait AgentTelemetry {
100 /// The name of the agent used for telemetry.
101 fn agent_name(&self) -> String;
102
103 /// A representation of the current thread state that can be serialized for
104 /// storage with telemetry events.
105 fn thread_data(
106 &self,
107 session_id: &acp::SessionId,
108 cx: &mut App,
109 ) -> Task<Result<serde_json::Value>>;
110}
111
112#[derive(Debug)]
113pub struct AuthRequired {
114 pub description: Option<String>,
115 pub provider_id: Option<LanguageModelProviderId>,
116}
117
118impl AuthRequired {
119 pub fn new() -> Self {
120 Self {
121 description: None,
122 provider_id: None,
123 }
124 }
125
126 pub fn with_description(mut self, description: String) -> Self {
127 self.description = Some(description);
128 self
129 }
130
131 pub fn with_language_model_provider(mut self, provider_id: LanguageModelProviderId) -> Self {
132 self.provider_id = Some(provider_id);
133 self
134 }
135}
136
137impl Error for AuthRequired {}
138impl fmt::Display for AuthRequired {
139 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
140 write!(f, "Authentication required")
141 }
142}
143
144/// Trait for agents that support listing, selecting, and querying language models.
145///
146/// This is an optional capability; agents indicate support via [AgentConnection::model_selector].
147pub trait AgentModelSelector: 'static {
148 /// Lists all available language models for this agent.
149 ///
150 /// # Parameters
151 /// - `cx`: The GPUI app context for async operations and global access.
152 ///
153 /// # Returns
154 /// A task resolving to the list of models or an error (e.g., if no models are configured).
155 fn list_models(&self, cx: &mut App) -> Task<Result<AgentModelList>>;
156
157 /// Selects a model for a specific session (thread).
158 ///
159 /// This sets the default model for future interactions in the session.
160 /// If the session doesn't exist or the model is invalid, it returns an error.
161 ///
162 /// # Parameters
163 /// - `session_id`: The ID of the session (thread) to apply the model to.
164 /// - `model`: The model to select (should be one from [list_models]).
165 /// - `cx`: The GPUI app context.
166 ///
167 /// # Returns
168 /// A task resolving to `Ok(())` on success or an error.
169 fn select_model(
170 &self,
171 session_id: acp::SessionId,
172 model_id: AgentModelId,
173 cx: &mut App,
174 ) -> Task<Result<()>>;
175
176 /// Retrieves the currently selected model for a specific session (thread).
177 ///
178 /// # Parameters
179 /// - `session_id`: The ID of the session (thread) to query.
180 /// - `cx`: The GPUI app context.
181 ///
182 /// # Returns
183 /// A task resolving to the selected model (always set) or an error (e.g., session not found).
184 fn selected_model(
185 &self,
186 session_id: &acp::SessionId,
187 cx: &mut App,
188 ) -> Task<Result<AgentModelInfo>>;
189
190 /// Whenever the model list is updated the receiver will be notified.
191 fn watch(&self, cx: &mut App) -> watch::Receiver<()>;
192}
193
194#[derive(Debug, Clone, PartialEq, Eq, Hash)]
195pub struct AgentModelId(pub SharedString);
196
197impl std::ops::Deref for AgentModelId {
198 type Target = SharedString;
199
200 fn deref(&self) -> &Self::Target {
201 &self.0
202 }
203}
204
205impl fmt::Display for AgentModelId {
206 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
207 self.0.fmt(f)
208 }
209}
210
211#[derive(Debug, Clone, PartialEq, Eq)]
212pub struct AgentModelInfo {
213 pub id: AgentModelId,
214 pub name: SharedString,
215 pub icon: Option<IconName>,
216}
217
218#[derive(Debug, Clone, PartialEq, Eq, Hash)]
219pub struct AgentModelGroupName(pub SharedString);
220
221#[derive(Debug, Clone)]
222pub enum AgentModelList {
223 Flat(Vec<AgentModelInfo>),
224 Grouped(IndexMap<AgentModelGroupName, Vec<AgentModelInfo>>),
225}
226
227impl AgentModelList {
228 pub fn is_empty(&self) -> bool {
229 match self {
230 AgentModelList::Flat(models) => models.is_empty(),
231 AgentModelList::Grouped(groups) => groups.is_empty(),
232 }
233 }
234}
235
236#[cfg(feature = "test-support")]
237mod test_support {
238 use std::sync::Arc;
239
240 use action_log::ActionLog;
241 use collections::HashMap;
242 use futures::{channel::oneshot, future::try_join_all};
243 use gpui::{AppContext as _, WeakEntity};
244 use parking_lot::Mutex;
245
246 use super::*;
247
248 #[derive(Clone, Default)]
249 pub struct StubAgentConnection {
250 sessions: Arc<Mutex<HashMap<acp::SessionId, Session>>>,
251 permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
252 next_prompt_updates: Arc<Mutex<Vec<acp::SessionUpdate>>>,
253 }
254
255 struct Session {
256 thread: WeakEntity<AcpThread>,
257 response_tx: Option<oneshot::Sender<acp::StopReason>>,
258 }
259
260 impl StubAgentConnection {
261 pub fn new() -> Self {
262 Self {
263 next_prompt_updates: Default::default(),
264 permission_requests: HashMap::default(),
265 sessions: Arc::default(),
266 }
267 }
268
269 pub fn set_next_prompt_updates(&self, updates: Vec<acp::SessionUpdate>) {
270 *self.next_prompt_updates.lock() = updates;
271 }
272
273 pub fn with_permission_requests(
274 mut self,
275 permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
276 ) -> Self {
277 self.permission_requests = permission_requests;
278 self
279 }
280
281 pub fn send_update(
282 &self,
283 session_id: acp::SessionId,
284 update: acp::SessionUpdate,
285 cx: &mut App,
286 ) {
287 assert!(
288 self.next_prompt_updates.lock().is_empty(),
289 "Use either send_update or set_next_prompt_updates"
290 );
291
292 self.sessions
293 .lock()
294 .get(&session_id)
295 .unwrap()
296 .thread
297 .update(cx, |thread, cx| {
298 thread.handle_session_update(update, cx).unwrap();
299 })
300 .unwrap();
301 }
302
303 pub fn end_turn(&self, session_id: acp::SessionId, stop_reason: acp::StopReason) {
304 self.sessions
305 .lock()
306 .get_mut(&session_id)
307 .unwrap()
308 .response_tx
309 .take()
310 .expect("No pending turn")
311 .send(stop_reason)
312 .unwrap();
313 }
314 }
315
316 impl AgentConnection for StubAgentConnection {
317 fn auth_methods(&self) -> &[acp::AuthMethod] {
318 &[]
319 }
320
321 fn new_thread(
322 self: Rc<Self>,
323 project: Entity<Project>,
324 _cwd: &Path,
325 cx: &mut gpui::App,
326 ) -> Task<gpui::Result<Entity<AcpThread>>> {
327 let session_id = acp::SessionId(self.sessions.lock().len().to_string().into());
328 let action_log = cx.new(|_| ActionLog::new(project.clone()));
329 let thread = cx.new(|cx| {
330 AcpThread::new(
331 "Test",
332 self.clone(),
333 project,
334 action_log,
335 session_id.clone(),
336 watch::Receiver::constant(acp::PromptCapabilities {
337 image: true,
338 audio: true,
339 embedded_context: true,
340 }),
341 cx,
342 )
343 });
344 self.sessions.lock().insert(
345 session_id,
346 Session {
347 thread: thread.downgrade(),
348 response_tx: None,
349 },
350 );
351 Task::ready(Ok(thread))
352 }
353
354 fn authenticate(
355 &self,
356 _method_id: acp::AuthMethodId,
357 _cx: &mut App,
358 ) -> Task<gpui::Result<()>> {
359 unimplemented!()
360 }
361
362 fn prompt(
363 &self,
364 _id: Option<UserMessageId>,
365 params: acp::PromptRequest,
366 cx: &mut App,
367 ) -> Task<gpui::Result<acp::PromptResponse>> {
368 let mut sessions = self.sessions.lock();
369 let Session {
370 thread,
371 response_tx,
372 } = sessions.get_mut(¶ms.session_id).unwrap();
373 let mut tasks = vec![];
374 if self.next_prompt_updates.lock().is_empty() {
375 let (tx, rx) = oneshot::channel();
376 response_tx.replace(tx);
377 cx.spawn(async move |_| {
378 let stop_reason = rx.await?;
379 Ok(acp::PromptResponse { stop_reason })
380 })
381 } else {
382 for update in self.next_prompt_updates.lock().drain(..) {
383 let thread = thread.clone();
384 let update = update.clone();
385 let permission_request = if let acp::SessionUpdate::ToolCall(tool_call) =
386 &update
387 && let Some(options) = self.permission_requests.get(&tool_call.id)
388 {
389 Some((tool_call.clone(), options.clone()))
390 } else {
391 None
392 };
393 let task = cx.spawn(async move |cx| {
394 if let Some((tool_call, options)) = permission_request {
395 thread
396 .update(cx, |thread, cx| {
397 thread.request_tool_call_authorization(
398 tool_call.clone().into(),
399 options.clone(),
400 cx,
401 )
402 })??
403 .await;
404 }
405 thread.update(cx, |thread, cx| {
406 thread.handle_session_update(update.clone(), cx).unwrap();
407 })?;
408 anyhow::Ok(())
409 });
410 tasks.push(task);
411 }
412
413 cx.spawn(async move |_| {
414 try_join_all(tasks).await?;
415 Ok(acp::PromptResponse {
416 stop_reason: acp::StopReason::EndTurn,
417 })
418 })
419 }
420 }
421
422 fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
423 if let Some(end_turn_tx) = self
424 .sessions
425 .lock()
426 .get_mut(session_id)
427 .unwrap()
428 .response_tx
429 .take()
430 {
431 end_turn_tx.send(acp::StopReason::Cancelled).unwrap();
432 }
433 }
434
435 fn truncate(
436 &self,
437 _session_id: &agent_client_protocol::SessionId,
438 _cx: &App,
439 ) -> Option<Rc<dyn AgentSessionTruncate>> {
440 Some(Rc::new(StubAgentSessionEditor))
441 }
442
443 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
444 self
445 }
446 }
447
448 struct StubAgentSessionEditor;
449
450 impl AgentSessionTruncate for StubAgentSessionEditor {
451 fn run(&self, _: UserMessageId, _: &mut App) -> Task<Result<()>> {
452 Task::ready(Ok(()))
453 }
454 }
455}
456
457#[cfg(feature = "test-support")]
458pub use test_support::*;