1use crate::AcpThread;
2use agent_client_protocol::{self as acp};
3use anyhow::Result;
4use collections::IndexMap;
5use gpui::{Entity, SharedString, Task};
6use language_model::LanguageModelProviderId;
7use project::Project;
8use std::{any::Any, error::Error, fmt, path::Path, rc::Rc};
9use ui::{App, IconName};
10
11pub trait AgentConnection {
12 fn new_thread(
13 self: Rc<Self>,
14 project: Entity<Project>,
15 cwd: &Path,
16 cx: &mut App,
17 ) -> Task<Result<Entity<AcpThread>>>;
18
19 fn auth_methods(&self) -> &[acp::AuthMethod];
20
21 fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;
22
23 fn prompt(&self, params: acp::PromptRequest, cx: &mut App)
24 -> Task<Result<acp::PromptResponse>>;
25
26 fn resume(
27 &self,
28 _session_id: &acp::SessionId,
29 _cx: &App,
30 ) -> Option<Rc<dyn AgentSessionResume>> {
31 None
32 }
33
34 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
35
36 fn rewind(
37 &self,
38 _session_id: &acp::SessionId,
39 _cx: &App,
40 ) -> Option<Rc<dyn AgentSessionRewind>> {
41 None
42 }
43
44 fn set_title(
45 &self,
46 _session_id: &acp::SessionId,
47 _cx: &App,
48 ) -> Option<Rc<dyn AgentSessionSetTitle>> {
49 None
50 }
51
52 /// Returns this agent as an [Rc<dyn ModelSelector>] if the model selection capability is supported.
53 ///
54 /// If the agent does not support model selection, returns [None].
55 /// This allows sharing the selector in UI components.
56 fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
57 None
58 }
59
60 fn telemetry(&self) -> Option<Rc<dyn AgentTelemetry>> {
61 None
62 }
63
64 fn into_any(self: Rc<Self>) -> Rc<dyn Any>;
65}
66
67impl dyn AgentConnection {
68 pub fn downcast<T: 'static + AgentConnection + Sized>(self: Rc<Self>) -> Option<Rc<T>> {
69 self.into_any().downcast().ok()
70 }
71}
72
/// Capability to rewind a session, obtained from [`AgentConnection::rewind`].
pub trait AgentSessionRewind {
    /// Rewinds the session to the prompt identified by `message_id`.
    ///
    /// NOTE(review): whether the targeted prompt itself is kept or discarded is up to the
    /// implementor; confirm against the agent implementation.
    fn rewind(&self, message_id: acp::PromptId, cx: &mut App) -> Task<Result<()>>;
}
76
/// Capability to resume a session, obtained from [`AgentConnection::resume`].
pub trait AgentSessionResume {
    /// Resumes the session; the task resolves with the resulting [`acp::PromptResponse`].
    fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>>;
}
80
/// Capability to set a session's title, obtained from [`AgentConnection::set_title`].
pub trait AgentSessionSetTitle {
    /// Sets the session's title to `title`.
    fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>>;
}
84
/// Telemetry hooks for an agent, obtained from [`AgentConnection::telemetry`].
pub trait AgentTelemetry {
    /// The name of the agent used for telemetry.
    fn agent_name(&self) -> String;

    /// A representation of the current thread state for `session_id` that can be
    /// serialized for storage with telemetry events.
    fn thread_data(
        &self,
        session_id: &acp::SessionId,
        cx: &mut App,
    ) -> Task<Result<serde_json::Value>>;
}
97
98#[derive(Debug)]
99pub struct AuthRequired {
100 pub description: Option<String>,
101 pub provider_id: Option<LanguageModelProviderId>,
102}
103
104impl AuthRequired {
105 pub fn new() -> Self {
106 Self {
107 description: None,
108 provider_id: None,
109 }
110 }
111
112 pub fn with_description(mut self, description: String) -> Self {
113 self.description = Some(description);
114 self
115 }
116
117 pub fn with_language_model_provider(mut self, provider_id: LanguageModelProviderId) -> Self {
118 self.provider_id = Some(provider_id);
119 self
120 }
121}
122
123impl Error for AuthRequired {}
124impl fmt::Display for AuthRequired {
125 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
126 write!(f, "Authentication required")
127 }
128}
129
130/// Trait for agents that support listing, selecting, and querying language models.
131///
132/// This is an optional capability; agents indicate support via [AgentConnection::model_selector].
133pub trait AgentModelSelector: 'static {
134 /// Lists all available language models for this agent.
135 ///
136 /// # Parameters
137 /// - `cx`: The GPUI app context for async operations and global access.
138 ///
139 /// # Returns
140 /// A task resolving to the list of models or an error (e.g., if no models are configured).
141 fn list_models(&self, cx: &mut App) -> Task<Result<AgentModelList>>;
142
143 /// Selects a model for a specific session (thread).
144 ///
145 /// This sets the default model for future interactions in the session.
146 /// If the session doesn't exist or the model is invalid, it returns an error.
147 ///
148 /// # Parameters
149 /// - `session_id`: The ID of the session (thread) to apply the model to.
150 /// - `model`: The model to select (should be one from [list_models]).
151 /// - `cx`: The GPUI app context.
152 ///
153 /// # Returns
154 /// A task resolving to `Ok(())` on success or an error.
155 fn select_model(
156 &self,
157 session_id: acp::SessionId,
158 model_id: AgentModelId,
159 cx: &mut App,
160 ) -> Task<Result<()>>;
161
162 /// Retrieves the currently selected model for a specific session (thread).
163 ///
164 /// # Parameters
165 /// - `session_id`: The ID of the session (thread) to query.
166 /// - `cx`: The GPUI app context.
167 ///
168 /// # Returns
169 /// A task resolving to the selected model (always set) or an error (e.g., session not found).
170 fn selected_model(
171 &self,
172 session_id: &acp::SessionId,
173 cx: &mut App,
174 ) -> Task<Result<AgentModelInfo>>;
175
176 /// Whenever the model list is updated the receiver will be notified.
177 fn watch(&self, cx: &mut App) -> watch::Receiver<()>;
178}
179
180#[derive(Debug, Clone, PartialEq, Eq, Hash)]
181pub struct AgentModelId(pub SharedString);
182
183impl std::ops::Deref for AgentModelId {
184 type Target = SharedString;
185
186 fn deref(&self) -> &Self::Target {
187 &self.0
188 }
189}
190
191impl fmt::Display for AgentModelId {
192 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
193 self.0.fmt(f)
194 }
195}
196
197#[derive(Debug, Clone, PartialEq, Eq)]
198pub struct AgentModelInfo {
199 pub id: AgentModelId,
200 pub name: SharedString,
201 pub icon: Option<IconName>,
202}
203
204#[derive(Debug, Clone, PartialEq, Eq, Hash)]
205pub struct AgentModelGroupName(pub SharedString);
206
207#[derive(Debug, Clone)]
208pub enum AgentModelList {
209 Flat(Vec<AgentModelInfo>),
210 Grouped(IndexMap<AgentModelGroupName, Vec<AgentModelInfo>>),
211}
212
213impl AgentModelList {
214 pub fn is_empty(&self) -> bool {
215 match self {
216 AgentModelList::Flat(models) => models.is_empty(),
217 AgentModelList::Grouped(groups) => groups.is_empty(),
218 }
219 }
220}
221
222#[cfg(feature = "test-support")]
223mod test_support {
224 use std::sync::Arc;
225
226 use action_log::ActionLog;
227 use collections::HashMap;
228 use futures::{channel::oneshot, future::try_join_all};
229 use gpui::{AppContext as _, WeakEntity};
230 use parking_lot::Mutex;
231
232 use super::*;
233
234 #[derive(Clone, Default)]
235 pub struct StubAgentConnection {
236 sessions: Arc<Mutex<HashMap<acp::SessionId, Session>>>,
237 permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
238 next_prompt_updates: Arc<Mutex<Vec<acp::SessionUpdate>>>,
239 }
240
241 struct Session {
242 thread: WeakEntity<AcpThread>,
243 response_tx: Option<oneshot::Sender<acp::StopReason>>,
244 }
245
246 impl StubAgentConnection {
247 pub fn new() -> Self {
248 Self {
249 next_prompt_updates: Default::default(),
250 permission_requests: HashMap::default(),
251 sessions: Arc::default(),
252 }
253 }
254
255 pub fn set_next_prompt_updates(&self, updates: Vec<acp::SessionUpdate>) {
256 *self.next_prompt_updates.lock() = updates;
257 }
258
259 pub fn with_permission_requests(
260 mut self,
261 permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
262 ) -> Self {
263 self.permission_requests = permission_requests;
264 self
265 }
266
267 pub fn send_update(
268 &self,
269 session_id: acp::SessionId,
270 update: acp::SessionUpdate,
271 cx: &mut App,
272 ) {
273 assert!(
274 self.next_prompt_updates.lock().is_empty(),
275 "Use either send_update or set_next_prompt_updates"
276 );
277
278 self.sessions
279 .lock()
280 .get(&session_id)
281 .unwrap()
282 .thread
283 .update(cx, |thread, cx| {
284 thread.handle_session_update(update, cx).unwrap();
285 })
286 .unwrap();
287 }
288
289 pub fn end_turn(&self, session_id: acp::SessionId, stop_reason: acp::StopReason) {
290 self.sessions
291 .lock()
292 .get_mut(&session_id)
293 .unwrap()
294 .response_tx
295 .take()
296 .expect("No pending turn")
297 .send(stop_reason)
298 .unwrap();
299 }
300 }
301
302 impl AgentConnection for StubAgentConnection {
303 fn auth_methods(&self) -> &[acp::AuthMethod] {
304 &[]
305 }
306
307 fn new_thread(
308 self: Rc<Self>,
309 project: Entity<Project>,
310 _cwd: &Path,
311 cx: &mut gpui::App,
312 ) -> Task<gpui::Result<Entity<AcpThread>>> {
313 let session_id = acp::SessionId(self.sessions.lock().len().to_string().into());
314 let action_log = cx.new(|_| ActionLog::new(project.clone()));
315 let thread = cx.new(|cx| {
316 AcpThread::new(
317 "Test",
318 self.clone(),
319 project,
320 action_log,
321 session_id.clone(),
322 watch::Receiver::constant(acp::PromptCapabilities {
323 image: true,
324 audio: true,
325 embedded_context: true,
326 }),
327 cx,
328 )
329 });
330 self.sessions.lock().insert(
331 session_id,
332 Session {
333 thread: thread.downgrade(),
334 response_tx: None,
335 },
336 );
337 Task::ready(Ok(thread))
338 }
339
340 fn authenticate(
341 &self,
342 _method_id: acp::AuthMethodId,
343 _cx: &mut App,
344 ) -> Task<gpui::Result<()>> {
345 unimplemented!()
346 }
347
348 fn prompt(
349 &self,
350 params: acp::PromptRequest,
351 cx: &mut App,
352 ) -> Task<gpui::Result<acp::PromptResponse>> {
353 let mut sessions = self.sessions.lock();
354 let Session {
355 thread,
356 response_tx,
357 } = sessions.get_mut(¶ms.session_id).unwrap();
358 let mut tasks = vec![];
359 if self.next_prompt_updates.lock().is_empty() {
360 let (tx, rx) = oneshot::channel();
361 response_tx.replace(tx);
362 cx.spawn(async move |_| {
363 let stop_reason = rx.await?;
364 Ok(acp::PromptResponse { stop_reason })
365 })
366 } else {
367 for update in self.next_prompt_updates.lock().drain(..) {
368 let thread = thread.clone();
369 let update = update.clone();
370 let permission_request = if let acp::SessionUpdate::ToolCall(tool_call) =
371 &update
372 && let Some(options) = self.permission_requests.get(&tool_call.id)
373 {
374 Some((tool_call.clone(), options.clone()))
375 } else {
376 None
377 };
378 let task = cx.spawn(async move |cx| {
379 if let Some((tool_call, options)) = permission_request {
380 let permission = thread.update(cx, |thread, cx| {
381 thread.request_tool_call_authorization(
382 tool_call.clone().into(),
383 options.clone(),
384 cx,
385 )
386 })?;
387 permission?.await?;
388 }
389 thread.update(cx, |thread, cx| {
390 thread.handle_session_update(update.clone(), cx).unwrap();
391 })?;
392 anyhow::Ok(())
393 });
394 tasks.push(task);
395 }
396
397 cx.spawn(async move |_| {
398 try_join_all(tasks).await?;
399 Ok(acp::PromptResponse {
400 stop_reason: acp::StopReason::EndTurn,
401 })
402 })
403 }
404 }
405
406 fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
407 if let Some(end_turn_tx) = self
408 .sessions
409 .lock()
410 .get_mut(session_id)
411 .unwrap()
412 .response_tx
413 .take()
414 {
415 end_turn_tx.send(acp::StopReason::Cancelled).unwrap();
416 }
417 }
418
419 fn rewind(
420 &self,
421 _session_id: &agent_client_protocol::SessionId,
422 _cx: &App,
423 ) -> Option<Rc<dyn AgentSessionRewind>> {
424 Some(Rc::new(StubAgentSessionEditor))
425 }
426
427 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
428 self
429 }
430 }
431
432 struct StubAgentSessionEditor;
433
434 impl AgentSessionRewind for StubAgentSessionEditor {
435 fn rewind(&self, _: acp::PromptId, _: &mut App) -> Task<Result<()>> {
436 Task::ready(Ok(()))
437 }
438 }
439}
440
441#[cfg(feature = "test-support")]
442pub use test_support::*;