1use crate::AcpThread;
2use agent_client_protocol::{self as acp};
3use anyhow::Result;
4use collections::IndexMap;
5use gpui::{Entity, SharedString, Task};
6use language_model::LanguageModelProviderId;
7use project::Project;
8use serde::{Deserialize, Serialize};
9use std::{any::Any, error::Error, fmt, path::Path, rc::Rc, sync::Arc};
10use ui::{App, IconName};
11use uuid::Uuid;
12
/// Identifier of a user message, passed to [`AgentConnection::prompt`] and
/// [`AgentSessionTruncate::run`].
///
/// Wraps an `Arc<str>`, so cloning only bumps a reference count.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
pub struct UserMessageId(Arc<str>);
15
16impl UserMessageId {
17 pub fn new() -> Self {
18 Self(Uuid::new_v4().to_string().into())
19 }
20}
21
22pub trait AgentConnection {
23 fn new_thread(
24 self: Rc<Self>,
25 project: Entity<Project>,
26 cwd: &Path,
27 cx: &mut App,
28 ) -> Task<Result<Entity<AcpThread>>>;
29
30 fn auth_methods(&self) -> &[acp::AuthMethod];
31
32 fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;
33
34 fn prompt(
35 &self,
36 user_message_id: Option<UserMessageId>,
37 params: acp::PromptRequest,
38 cx: &mut App,
39 ) -> Task<Result<acp::PromptResponse>>;
40
41 fn resume(
42 &self,
43 _session_id: &acp::SessionId,
44 _cx: &App,
45 ) -> Option<Rc<dyn AgentSessionResume>> {
46 None
47 }
48
49 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
50
51 fn truncate(
52 &self,
53 _session_id: &acp::SessionId,
54 _cx: &App,
55 ) -> Option<Rc<dyn AgentSessionTruncate>> {
56 None
57 }
58
59 fn set_title(
60 &self,
61 _session_id: &acp::SessionId,
62 _cx: &App,
63 ) -> Option<Rc<dyn AgentSessionSetTitle>> {
64 None
65 }
66
67 /// Returns this agent as an [Rc<dyn ModelSelector>] if the model selection capability is supported.
68 ///
69 /// If the agent does not support model selection, returns [None].
70 /// This allows sharing the selector in UI components.
71 fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
72 None
73 }
74
75 fn telemetry(&self) -> Option<Rc<dyn AgentTelemetry>> {
76 None
77 }
78
79 fn into_any(self: Rc<Self>) -> Rc<dyn Any>;
80}
81
82impl dyn AgentConnection {
83 pub fn downcast<T: 'static + AgentConnection + Sized>(self: Rc<Self>) -> Option<Rc<T>> {
84 self.into_any().downcast().ok()
85 }
86}
87
/// Per-session handle returned by [`AgentConnection::truncate`].
pub trait AgentSessionTruncate {
    /// Truncates the session at the user message `message_id`.
    ///
    /// NOTE(review): whether `message_id` itself is kept or removed is decided by
    /// the implementor — confirm against implementations.
    fn run(&self, message_id: UserMessageId, cx: &mut App) -> Task<Result<()>>;
}
91
/// Per-session handle returned by [`AgentConnection::resume`].
pub trait AgentSessionResume {
    /// Resumes the session, resolving to the agent's [`acp::PromptResponse`].
    fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>>;
}
95
/// Per-session handle returned by [`AgentConnection::set_title`].
pub trait AgentSessionSetTitle {
    /// Sets the session's title to `title`.
    fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>>;
}
99
/// Telemetry hooks an agent exposes via [`AgentConnection::telemetry`].
pub trait AgentTelemetry {
    /// The name of the agent used for telemetry.
    fn agent_name(&self) -> String;

    /// A representation of the current state of the thread for `session_id`
    /// that can be serialized for storage with telemetry events.
    fn thread_data(
        &self,
        session_id: &acp::SessionId,
        cx: &mut App,
    ) -> Task<Result<serde_json::Value>>;
}
112
/// Error signalling that authentication is required before proceeding.
///
/// Built with [`AuthRequired::new`] plus the `with_*` builder methods.
#[derive(Debug)]
pub struct AuthRequired {
    /// Optional extra detail about the authentication requirement.
    pub description: Option<String>,
    /// The language model provider that needs authentication, if known.
    pub provider_id: Option<LanguageModelProviderId>,
}
118
119impl AuthRequired {
120 pub fn new() -> Self {
121 Self {
122 description: None,
123 provider_id: None,
124 }
125 }
126
127 pub fn with_description(mut self, description: String) -> Self {
128 self.description = Some(description);
129 self
130 }
131
132 pub fn with_language_model_provider(mut self, provider_id: LanguageModelProviderId) -> Self {
133 self.provider_id = Some(provider_id);
134 self
135 }
136}
137
138impl Error for AuthRequired {}
139impl fmt::Display for AuthRequired {
140 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
141 write!(f, "Authentication required")
142 }
143}
144
145/// Trait for agents that support listing, selecting, and querying language models.
146///
147/// This is an optional capability; agents indicate support via [AgentConnection::model_selector].
148pub trait AgentModelSelector: 'static {
149 /// Lists all available language models for this agent.
150 ///
151 /// # Parameters
152 /// - `cx`: The GPUI app context for async operations and global access.
153 ///
154 /// # Returns
155 /// A task resolving to the list of models or an error (e.g., if no models are configured).
156 fn list_models(&self, cx: &mut App) -> Task<Result<AgentModelList>>;
157
158 /// Selects a model for a specific session (thread).
159 ///
160 /// This sets the default model for future interactions in the session.
161 /// If the session doesn't exist or the model is invalid, it returns an error.
162 ///
163 /// # Parameters
164 /// - `session_id`: The ID of the session (thread) to apply the model to.
165 /// - `model`: The model to select (should be one from [list_models]).
166 /// - `cx`: The GPUI app context.
167 ///
168 /// # Returns
169 /// A task resolving to `Ok(())` on success or an error.
170 fn select_model(
171 &self,
172 session_id: acp::SessionId,
173 model_id: AgentModelId,
174 cx: &mut App,
175 ) -> Task<Result<()>>;
176
177 /// Retrieves the currently selected model for a specific session (thread).
178 ///
179 /// # Parameters
180 /// - `session_id`: The ID of the session (thread) to query.
181 /// - `cx`: The GPUI app context.
182 ///
183 /// # Returns
184 /// A task resolving to the selected model (always set) or an error (e.g., session not found).
185 fn selected_model(
186 &self,
187 session_id: &acp::SessionId,
188 cx: &mut App,
189 ) -> Task<Result<AgentModelInfo>>;
190
191 /// Whenever the model list is updated the receiver will be notified.
192 fn watch(&self, cx: &mut App) -> watch::Receiver<()>;
193}
194
/// Identifier of a language model, as passed to [`AgentModelSelector::select_model`].
///
/// Dereferences to, and displays as, the wrapped [`SharedString`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct AgentModelId(pub SharedString);
197
198impl std::ops::Deref for AgentModelId {
199 type Target = SharedString;
200
201 fn deref(&self) -> &Self::Target {
202 &self.0
203 }
204}
205
206impl fmt::Display for AgentModelId {
207 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
208 self.0.fmt(f)
209 }
210}
211
/// Describes one language model, as listed in an [`AgentModelList`] or returned
/// by [`AgentModelSelector::selected_model`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AgentModelInfo {
    /// Identifier used with [`AgentModelSelector::select_model`].
    pub id: AgentModelId,
    /// Name of the model (presumably the user-facing label — confirm).
    pub name: SharedString,
    /// Optional icon for the model.
    pub icon: Option<IconName>,
}
218
/// Name of a group of models in [`AgentModelList::Grouped`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct AgentModelGroupName(pub SharedString);
221
/// Models returned by [`AgentModelSelector::list_models`].
#[derive(Debug, Clone)]
pub enum AgentModelList {
    /// An ungrouped list of models.
    Flat(Vec<AgentModelInfo>),
    /// Models grouped by name; `IndexMap` keeps groups in insertion order.
    Grouped(IndexMap<AgentModelGroupName, Vec<AgentModelInfo>>),
}
227
228impl AgentModelList {
229 pub fn is_empty(&self) -> bool {
230 match self {
231 AgentModelList::Flat(models) => models.is_empty(),
232 AgentModelList::Grouped(groups) => groups.is_empty(),
233 }
234 }
235
236 pub fn len(&self) -> usize {
237 match self {
238 AgentModelList::Flat(models) => models.len(),
239 AgentModelList::Grouped(groups) => groups.values().len(),
240 }
241 }
242}
243
#[cfg(feature = "test-support")]
mod test_support {
    use std::sync::Arc;

    use action_log::ActionLog;
    use collections::HashMap;
    use futures::{channel::oneshot, future::try_join_all};
    use gpui::{AppContext as _, WeakEntity};
    use parking_lot::Mutex;

    use super::*;

    /// An in-memory [`AgentConnection`] for tests.
    ///
    /// A prompt is driven in one of two ways; they must not be mixed:
    /// - Updates queued with [`Self::set_next_prompt_updates`] are applied by the
    ///   next `prompt` call, which then finishes with `StopReason::EndTurn`.
    /// - With nothing queued, `prompt` stays pending until [`Self::end_turn`] or
    ///   `cancel` supplies a stop reason. Meanwhile the test can push updates
    ///   with [`Self::send_update`].
    #[derive(Clone, Default)]
    pub struct StubAgentConnection {
        sessions: Arc<Mutex<HashMap<acp::SessionId, Session>>>,
        /// Permission options requested before applying a `ToolCall` update with the matching id.
        permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
        next_prompt_updates: Arc<Mutex<Vec<acp::SessionUpdate>>>,
    }

    /// Per-session state tracked by [`StubAgentConnection`].
    struct Session {
        thread: WeakEntity<AcpThread>,
        /// Completes the pending `prompt` call, if a turn is in progress.
        response_tx: Option<oneshot::Sender<acp::StopReason>>,
    }

    impl StubAgentConnection {
        /// Creates a connection with no sessions, queued updates, or permission requests.
        pub fn new() -> Self {
            Self::default()
        }

        /// Replaces the updates that the next `prompt` call will apply.
        pub fn set_next_prompt_updates(&self, updates: Vec<acp::SessionUpdate>) {
            *self.next_prompt_updates.lock() = updates;
        }

        /// Sets the permission options to request for tool calls, keyed by tool call id.
        pub fn with_permission_requests(
            mut self,
            permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
        ) -> Self {
            self.permission_requests = permission_requests;
            self
        }

        /// Immediately applies `update` to the thread of `session_id`.
        ///
        /// # Panics
        /// Panics in any of these cases:
        /// - updates are queued via `set_next_prompt_updates`;
        /// - the session is unknown;
        /// - its thread has been dropped;
        /// - the thread fails to handle the update.
        pub fn send_update(
            &self,
            session_id: acp::SessionId,
            update: acp::SessionUpdate,
            cx: &mut App,
        ) {
            assert!(
                self.next_prompt_updates.lock().is_empty(),
                "Use either send_update or set_next_prompt_updates"
            );

            self.sessions
                .lock()
                .get(&session_id)
                .unwrap()
                .thread
                .update(cx, |thread, cx| {
                    thread.handle_session_update(update, cx).unwrap();
                })
                .unwrap();
        }

        /// Completes the pending `prompt` call of `session_id` with `stop_reason`.
        ///
        /// # Panics
        /// Panics in any of these cases:
        /// - the session is unknown;
        /// - no turn is pending;
        /// - the pending prompt task has been dropped.
        pub fn end_turn(&self, session_id: acp::SessionId, stop_reason: acp::StopReason) {
            self.sessions
                .lock()
                .get_mut(&session_id)
                .unwrap()
                .response_tx
                .take()
                .expect("No pending turn")
                .send(stop_reason)
                .unwrap();
        }
    }

    impl AgentConnection for StubAgentConnection {
        fn auth_methods(&self) -> &[acp::AuthMethod] {
            &[]
        }

        fn new_thread(
            self: Rc<Self>,
            project: Entity<Project>,
            _cwd: &Path,
            cx: &mut gpui::App,
        ) -> Task<gpui::Result<Entity<AcpThread>>> {
            // Session ids are sequential: "0", "1", ...
            let session_id = acp::SessionId(self.sessions.lock().len().to_string().into());
            let action_log = cx.new(|_| ActionLog::new(project.clone()));
            let thread = cx.new(|cx| {
                AcpThread::new(
                    "Test",
                    self.clone(),
                    project,
                    action_log,
                    session_id.clone(),
                    watch::Receiver::constant(acp::PromptCapabilities {
                        image: true,
                        audio: true,
                        embedded_context: true,
                    }),
                    cx,
                )
            });
            self.sessions.lock().insert(
                session_id,
                Session {
                    thread: thread.downgrade(),
                    response_tx: None,
                },
            );
            Task::ready(Ok(thread))
        }

        fn authenticate(
            &self,
            _method_id: acp::AuthMethodId,
            _cx: &mut App,
        ) -> Task<gpui::Result<()>> {
            unimplemented!()
        }

        fn prompt(
            &self,
            _id: Option<UserMessageId>,
            params: acp::PromptRequest,
            cx: &mut App,
        ) -> Task<gpui::Result<acp::PromptResponse>> {
            let mut sessions = self.sessions.lock();
            let Session {
                thread,
                response_tx,
            } = sessions.get_mut(&params.session_id).unwrap();

            // Take the whole queue under a single lock so each update is applied exactly once.
            let updates = std::mem::take(&mut *self.next_prompt_updates.lock());

            if updates.is_empty() {
                // Nothing scripted: stay pending until `end_turn` or `cancel` sends a stop reason.
                let (tx, rx) = oneshot::channel();
                response_tx.replace(tx);
                return cx.spawn(async move |_| {
                    let stop_reason = rx.await?;
                    Ok(acp::PromptResponse { stop_reason })
                });
            }

            let tasks = updates
                .into_iter()
                .map(|update| {
                    let thread = thread.clone();
                    // Tool calls with configured options must be authorized before being applied.
                    let permission_request = match &update {
                        acp::SessionUpdate::ToolCall(tool_call) => self
                            .permission_requests
                            .get(&tool_call.id)
                            .map(|options| (tool_call.clone(), options.clone())),
                        _ => None,
                    };
                    cx.spawn(async move |cx| {
                        if let Some((tool_call, options)) = permission_request {
                            let permission = thread.update(cx, |thread, cx| {
                                thread.request_tool_call_authorization(
                                    tool_call.into(),
                                    options,
                                    cx,
                                )
                            })?;
                            permission?.await?;
                        }
                        thread.update(cx, |thread, cx| {
                            thread.handle_session_update(update, cx).unwrap();
                        })?;
                        anyhow::Ok(())
                    })
                })
                .collect::<Vec<_>>();

            cx.spawn(async move |_| {
                try_join_all(tasks).await?;
                Ok(acp::PromptResponse {
                    stop_reason: acp::StopReason::EndTurn,
                })
            })
        }

        fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
            if let Some(end_turn_tx) = self
                .sessions
                .lock()
                .get_mut(session_id)
                .unwrap()
                .response_tx
                .take()
            {
                end_turn_tx.send(acp::StopReason::Cancelled).unwrap();
            }
        }

        fn truncate(
            &self,
            _session_id: &acp::SessionId,
            _cx: &App,
        ) -> Option<Rc<dyn AgentSessionTruncate>> {
            Some(Rc::new(StubAgentSessionTruncate))
        }

        fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
            self
        }
    }

    /// Truncation handle that always succeeds without doing anything.
    struct StubAgentSessionTruncate;

    impl AgentSessionTruncate for StubAgentSessionTruncate {
        fn run(&self, _: UserMessageId, _: &mut App) -> Task<Result<()>> {
            Task::ready(Ok(()))
        }
    }
}
463
464#[cfg(feature = "test-support")]
465pub use test_support::*;