1use crate::AcpThread;
  2use agent_client_protocol::{self as acp};
  3use anyhow::Result;
  4use collections::IndexMap;
  5use gpui::{Entity, SharedString, Task};
  6use language_model::LanguageModelProviderId;
  7use project::Project;
  8use serde::{Deserialize, Serialize};
  9use std::{any::Any, error::Error, fmt, path::Path, rc::Rc, sync::Arc};
 10use ui::{App, IconName};
 11use uuid::Uuid;
 12
 13#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
 14pub struct UserMessageId(Arc<str>);
 15
 16impl UserMessageId {
 17    pub fn new() -> Self {
 18        Self(Uuid::new_v4().to_string().into())
 19    }
 20}
 21
 22pub trait AgentConnection {
 23    fn new_thread(
 24        self: Rc<Self>,
 25        project: Entity<Project>,
 26        cwd: &Path,
 27        cx: &mut App,
 28    ) -> Task<Result<Entity<AcpThread>>>;
 29
 30    fn auth_methods(&self) -> &[acp::AuthMethod];
 31
 32    fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;
 33
 34    fn prompt(
 35        &self,
 36        user_message_id: Option<UserMessageId>,
 37        params: acp::PromptRequest,
 38        cx: &mut App,
 39    ) -> Task<Result<acp::PromptResponse>>;
 40
 41    fn resume(
 42        &self,
 43        _session_id: &acp::SessionId,
 44        _cx: &App,
 45    ) -> Option<Rc<dyn AgentSessionResume>> {
 46        None
 47    }
 48
 49    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
 50
 51    fn truncate(
 52        &self,
 53        _session_id: &acp::SessionId,
 54        _cx: &App,
 55    ) -> Option<Rc<dyn AgentSessionTruncate>> {
 56        None
 57    }
 58
 59    fn set_title(
 60        &self,
 61        _session_id: &acp::SessionId,
 62        _cx: &App,
 63    ) -> Option<Rc<dyn AgentSessionSetTitle>> {
 64        None
 65    }
 66
 67    /// Returns this agent as an [Rc<dyn ModelSelector>] if the model selection capability is supported.
 68    ///
 69    /// If the agent does not support model selection, returns [None].
 70    /// This allows sharing the selector in UI components.
 71    fn model_selector(&self, _session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
 72        None
 73    }
 74
 75    fn telemetry(&self) -> Option<Rc<dyn AgentTelemetry>> {
 76        None
 77    }
 78
 79    fn session_modes(
 80        &self,
 81        _session_id: &acp::SessionId,
 82        _cx: &App,
 83    ) -> Option<Rc<dyn AgentSessionModes>> {
 84        None
 85    }
 86
 87    fn into_any(self: Rc<Self>) -> Rc<dyn Any>;
 88}
 89
 90impl dyn AgentConnection {
 91    pub fn downcast<T: 'static + AgentConnection + Sized>(self: Rc<Self>) -> Option<Rc<T>> {
 92        self.into_any().downcast().ok()
 93    }
 94}
 95
 96pub trait AgentSessionTruncate {
 97    fn run(&self, message_id: UserMessageId, cx: &mut App) -> Task<Result<()>>;
 98}
 99
100pub trait AgentSessionResume {
101    fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>>;
102}
103
104pub trait AgentSessionSetTitle {
105    fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>>;
106}
107
108pub trait AgentTelemetry {
109    /// The name of the agent used for telemetry.
110    fn agent_name(&self) -> String;
111
112    /// A representation of the current thread state that can be serialized for
113    /// storage with telemetry events.
114    fn thread_data(
115        &self,
116        session_id: &acp::SessionId,
117        cx: &mut App,
118    ) -> Task<Result<serde_json::Value>>;
119}
120
121pub trait AgentSessionModes {
122    fn current_mode(&self) -> acp::SessionModeId;
123
124    fn all_modes(&self) -> Vec<acp::SessionMode>;
125
126    fn set_mode(&self, mode: acp::SessionModeId, cx: &mut App) -> Task<Result<()>>;
127}
128
129#[derive(Debug)]
130pub struct AuthRequired {
131    pub description: Option<String>,
132    pub provider_id: Option<LanguageModelProviderId>,
133}
134
135impl AuthRequired {
136    pub fn new() -> Self {
137        Self {
138            description: None,
139            provider_id: None,
140        }
141    }
142
143    pub fn with_description(mut self, description: String) -> Self {
144        self.description = Some(description);
145        self
146    }
147
148    pub fn with_language_model_provider(mut self, provider_id: LanguageModelProviderId) -> Self {
149        self.provider_id = Some(provider_id);
150        self
151    }
152}
153
154impl Error for AuthRequired {}
155impl fmt::Display for AuthRequired {
156    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
157        write!(f, "Authentication required")
158    }
159}
160
161/// Trait for agents that support listing, selecting, and querying language models.
162///
163/// This is an optional capability; agents indicate support via [AgentConnection::model_selector].
164pub trait AgentModelSelector: 'static {
165    /// Lists all available language models for this agent.
166    ///
167    /// # Parameters
168    /// - `cx`: The GPUI app context for async operations and global access.
169    ///
170    /// # Returns
171    /// A task resolving to the list of models or an error (e.g., if no models are configured).
172    fn list_models(&self, cx: &mut App) -> Task<Result<AgentModelList>>;
173
174    /// Selects a model for a specific session (thread).
175    ///
176    /// This sets the default model for future interactions in the session.
177    /// If the session doesn't exist or the model is invalid, it returns an error.
178    ///
179    /// # Parameters
180    /// - `model`: The model to select (should be one from [list_models]).
181    /// - `cx`: The GPUI app context.
182    ///
183    /// # Returns
184    /// A task resolving to `Ok(())` on success or an error.
185    fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>>;
186
187    /// Retrieves the currently selected model for a specific session (thread).
188    ///
189    /// # Parameters
190    /// - `cx`: The GPUI app context.
191    ///
192    /// # Returns
193    /// A task resolving to the selected model (always set) or an error (e.g., session not found).
194    fn selected_model(&self, cx: &mut App) -> Task<Result<AgentModelInfo>>;
195
196    /// Whenever the model list is updated the receiver will be notified.
197    /// Optional for agents that don't update their model list.
198    fn watch(&self, _cx: &mut App) -> Option<watch::Receiver<()>> {
199        None
200    }
201}
202
203#[derive(Debug, Clone, PartialEq, Eq)]
204pub struct AgentModelInfo {
205    pub id: acp::ModelId,
206    pub name: SharedString,
207    pub description: Option<SharedString>,
208    pub icon: Option<IconName>,
209}
210
211impl From<acp::ModelInfo> for AgentModelInfo {
212    fn from(info: acp::ModelInfo) -> Self {
213        Self {
214            id: info.model_id,
215            name: info.name.into(),
216            description: info.description.map(|desc| desc.into()),
217            icon: None,
218        }
219    }
220}
221
222#[derive(Debug, Clone, PartialEq, Eq, Hash)]
223pub struct AgentModelGroupName(pub SharedString);
224
225#[derive(Debug, Clone)]
226pub enum AgentModelList {
227    Flat(Vec<AgentModelInfo>),
228    Grouped(IndexMap<AgentModelGroupName, Vec<AgentModelInfo>>),
229}
230
231impl AgentModelList {
232    pub fn is_empty(&self) -> bool {
233        match self {
234            AgentModelList::Flat(models) => models.is_empty(),
235            AgentModelList::Grouped(groups) => groups.is_empty(),
236        }
237    }
238}
239
240#[cfg(feature = "test-support")]
241mod test_support {
242    use std::sync::Arc;
243
244    use action_log::ActionLog;
245    use collections::HashMap;
246    use futures::{channel::oneshot, future::try_join_all};
247    use gpui::{AppContext as _, WeakEntity};
248    use parking_lot::Mutex;
249
250    use super::*;
251
252    #[derive(Clone, Default)]
253    pub struct StubAgentConnection {
254        sessions: Arc<Mutex<HashMap<acp::SessionId, Session>>>,
255        permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
256        next_prompt_updates: Arc<Mutex<Vec<acp::SessionUpdate>>>,
257    }
258
259    struct Session {
260        thread: WeakEntity<AcpThread>,
261        response_tx: Option<oneshot::Sender<acp::StopReason>>,
262    }
263
264    impl StubAgentConnection {
265        pub fn new() -> Self {
266            Self {
267                next_prompt_updates: Default::default(),
268                permission_requests: HashMap::default(),
269                sessions: Arc::default(),
270            }
271        }
272
273        pub fn set_next_prompt_updates(&self, updates: Vec<acp::SessionUpdate>) {
274            *self.next_prompt_updates.lock() = updates;
275        }
276
277        pub fn with_permission_requests(
278            mut self,
279            permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
280        ) -> Self {
281            self.permission_requests = permission_requests;
282            self
283        }
284
285        pub fn send_update(
286            &self,
287            session_id: acp::SessionId,
288            update: acp::SessionUpdate,
289            cx: &mut App,
290        ) {
291            assert!(
292                self.next_prompt_updates.lock().is_empty(),
293                "Use either send_update or set_next_prompt_updates"
294            );
295
296            self.sessions
297                .lock()
298                .get(&session_id)
299                .unwrap()
300                .thread
301                .update(cx, |thread, cx| {
302                    thread.handle_session_update(update, cx).unwrap();
303                })
304                .unwrap();
305        }
306
307        pub fn end_turn(&self, session_id: acp::SessionId, stop_reason: acp::StopReason) {
308            self.sessions
309                .lock()
310                .get_mut(&session_id)
311                .unwrap()
312                .response_tx
313                .take()
314                .expect("No pending turn")
315                .send(stop_reason)
316                .unwrap();
317        }
318    }
319
320    impl AgentConnection for StubAgentConnection {
321        fn auth_methods(&self) -> &[acp::AuthMethod] {
322            &[]
323        }
324
325        fn new_thread(
326            self: Rc<Self>,
327            project: Entity<Project>,
328            _cwd: &Path,
329            cx: &mut gpui::App,
330        ) -> Task<gpui::Result<Entity<AcpThread>>> {
331            let session_id = acp::SessionId(self.sessions.lock().len().to_string().into());
332            let action_log = cx.new(|_| ActionLog::new(project.clone()));
333            let thread = cx.new(|cx| {
334                AcpThread::new(
335                    "Test",
336                    self.clone(),
337                    project,
338                    action_log,
339                    session_id.clone(),
340                    watch::Receiver::constant(acp::PromptCapabilities {
341                        image: true,
342                        audio: true,
343                        embedded_context: true,
344                        meta: None,
345                    }),
346                    cx,
347                )
348            });
349            self.sessions.lock().insert(
350                session_id,
351                Session {
352                    thread: thread.downgrade(),
353                    response_tx: None,
354                },
355            );
356            Task::ready(Ok(thread))
357        }
358
359        fn authenticate(
360            &self,
361            _method_id: acp::AuthMethodId,
362            _cx: &mut App,
363        ) -> Task<gpui::Result<()>> {
364            unimplemented!()
365        }
366
367        fn prompt(
368            &self,
369            _id: Option<UserMessageId>,
370            params: acp::PromptRequest,
371            cx: &mut App,
372        ) -> Task<gpui::Result<acp::PromptResponse>> {
373            let mut sessions = self.sessions.lock();
374            let Session {
375                thread,
376                response_tx,
377            } = sessions.get_mut(¶ms.session_id).unwrap();
378            let mut tasks = vec![];
379            if self.next_prompt_updates.lock().is_empty() {
380                let (tx, rx) = oneshot::channel();
381                response_tx.replace(tx);
382                cx.spawn(async move |_| {
383                    let stop_reason = rx.await?;
384                    Ok(acp::PromptResponse {
385                        stop_reason,
386                        meta: None,
387                    })
388                })
389            } else {
390                for update in self.next_prompt_updates.lock().drain(..) {
391                    let thread = thread.clone();
392                    let update = update.clone();
393                    let permission_request = if let acp::SessionUpdate::ToolCall(tool_call) =
394                        &update
395                        && let Some(options) = self.permission_requests.get(&tool_call.id)
396                    {
397                        Some((tool_call.clone(), options.clone()))
398                    } else {
399                        None
400                    };
401                    let task = cx.spawn(async move |cx| {
402                        if let Some((tool_call, options)) = permission_request {
403                            thread
404                                .update(cx, |thread, cx| {
405                                    thread.request_tool_call_authorization(
406                                        tool_call.clone().into(),
407                                        options.clone(),
408                                        false,
409                                        cx,
410                                    )
411                                })??
412                                .await;
413                        }
414                        thread.update(cx, |thread, cx| {
415                            thread.handle_session_update(update.clone(), cx).unwrap();
416                        })?;
417                        anyhow::Ok(())
418                    });
419                    tasks.push(task);
420                }
421
422                cx.spawn(async move |_| {
423                    try_join_all(tasks).await?;
424                    Ok(acp::PromptResponse {
425                        stop_reason: acp::StopReason::EndTurn,
426                        meta: None,
427                    })
428                })
429            }
430        }
431
432        fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
433            if let Some(end_turn_tx) = self
434                .sessions
435                .lock()
436                .get_mut(session_id)
437                .unwrap()
438                .response_tx
439                .take()
440            {
441                end_turn_tx.send(acp::StopReason::Cancelled).unwrap();
442            }
443        }
444
445        fn truncate(
446            &self,
447            _session_id: &agent_client_protocol::SessionId,
448            _cx: &App,
449        ) -> Option<Rc<dyn AgentSessionTruncate>> {
450            Some(Rc::new(StubAgentSessionEditor))
451        }
452
453        fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
454            self
455        }
456    }
457
458    struct StubAgentSessionEditor;
459
460    impl AgentSessionTruncate for StubAgentSessionEditor {
461        fn run(&self, _: UserMessageId, _: &mut App) -> Task<Result<()>> {
462            Task::ready(Ok(()))
463        }
464    }
465}
466
467#[cfg(feature = "test-support")]
468pub use test_support::*;