connection.rs

  1use crate::AcpThread;
  2use agent_client_protocol::{self as acp};
  3use anyhow::Result;
  4use collections::IndexMap;
  5use gpui::{Entity, SharedString, Task};
  6use language_model::LanguageModelProviderId;
  7use project::Project;
  8use serde::{Deserialize, Serialize};
  9use std::{any::Any, error::Error, fmt, path::Path, rc::Rc, sync::Arc};
 10use ui::{App, IconName};
 11use uuid::Uuid;
 12
 13#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
 14pub struct UserMessageId(Arc<str>);
 15
 16impl UserMessageId {
 17    pub fn new() -> Self {
 18        Self(Uuid::new_v4().to_string().into())
 19    }
 20}
 21
 22pub trait AgentConnection {
 23    fn telemetry_id(&self) -> SharedString;
 24
 25    fn new_thread(
 26        self: Rc<Self>,
 27        project: Entity<Project>,
 28        cwd: &Path,
 29        cx: &mut App,
 30    ) -> Task<Result<Entity<AcpThread>>>;
 31
 32    fn auth_methods(&self) -> &[acp::AuthMethod];
 33
 34    fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;
 35
 36    fn prompt(
 37        &self,
 38        user_message_id: Option<UserMessageId>,
 39        params: acp::PromptRequest,
 40        cx: &mut App,
 41    ) -> Task<Result<acp::PromptResponse>>;
 42
 43    fn resume(
 44        &self,
 45        _session_id: &acp::SessionId,
 46        _cx: &App,
 47    ) -> Option<Rc<dyn AgentSessionResume>> {
 48        None
 49    }
 50
 51    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
 52
 53    fn truncate(
 54        &self,
 55        _session_id: &acp::SessionId,
 56        _cx: &App,
 57    ) -> Option<Rc<dyn AgentSessionTruncate>> {
 58        None
 59    }
 60
 61    fn set_title(
 62        &self,
 63        _session_id: &acp::SessionId,
 64        _cx: &App,
 65    ) -> Option<Rc<dyn AgentSessionSetTitle>> {
 66        None
 67    }
 68
 69    /// Returns this agent as an [Rc<dyn ModelSelector>] if the model selection capability is supported.
 70    ///
 71    /// If the agent does not support model selection, returns [None].
 72    /// This allows sharing the selector in UI components.
 73    fn model_selector(&self, _session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
 74        None
 75    }
 76
 77    fn telemetry(&self) -> Option<Rc<dyn AgentTelemetry>> {
 78        None
 79    }
 80
 81    fn session_modes(
 82        &self,
 83        _session_id: &acp::SessionId,
 84        _cx: &App,
 85    ) -> Option<Rc<dyn AgentSessionModes>> {
 86        None
 87    }
 88
 89    fn session_config_options(
 90        &self,
 91        _session_id: &acp::SessionId,
 92        _cx: &App,
 93    ) -> Option<Rc<dyn AgentSessionConfigOptions>> {
 94        None
 95    }
 96
 97    fn into_any(self: Rc<Self>) -> Rc<dyn Any>;
 98}
 99
100impl dyn AgentConnection {
101    pub fn downcast<T: 'static + AgentConnection + Sized>(self: Rc<Self>) -> Option<Rc<T>> {
102        self.into_any().downcast().ok()
103    }
104}
105
106pub trait AgentSessionTruncate {
107    fn run(&self, message_id: UserMessageId, cx: &mut App) -> Task<Result<()>>;
108}
109
110pub trait AgentSessionResume {
111    fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>>;
112}
113
114pub trait AgentSessionSetTitle {
115    fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>>;
116}
117
118pub trait AgentTelemetry {
119    /// A representation of the current thread state that can be serialized for
120    /// storage with telemetry events.
121    fn thread_data(
122        &self,
123        session_id: &acp::SessionId,
124        cx: &mut App,
125    ) -> Task<Result<serde_json::Value>>;
126}
127
128pub trait AgentSessionModes {
129    fn current_mode(&self) -> acp::SessionModeId;
130
131    fn all_modes(&self) -> Vec<acp::SessionMode>;
132
133    fn set_mode(&self, mode: acp::SessionModeId, cx: &mut App) -> Task<Result<()>>;
134}
135
136pub trait AgentSessionConfigOptions {
137    /// Get all current config options with their state
138    fn config_options(&self) -> Vec<acp::SessionConfigOption>;
139
140    /// Set a config option value
141    /// Returns the full updated list of config options
142    fn set_config_option(
143        &self,
144        config_id: acp::SessionConfigId,
145        value: acp::SessionConfigValueId,
146        cx: &mut App,
147    ) -> Task<Result<Vec<acp::SessionConfigOption>>>;
148
149    /// Whenever the config options are updated the receiver will be notified.
150    /// Optional for agents that don't update their config options dynamically.
151    fn watch(&self, _cx: &mut App) -> Option<watch::Receiver<()>> {
152        None
153    }
154}
155
156#[derive(Debug)]
157pub struct AuthRequired {
158    pub description: Option<String>,
159    pub provider_id: Option<LanguageModelProviderId>,
160}
161
162impl AuthRequired {
163    pub fn new() -> Self {
164        Self {
165            description: None,
166            provider_id: None,
167        }
168    }
169
170    pub fn with_description(mut self, description: String) -> Self {
171        self.description = Some(description);
172        self
173    }
174
175    pub fn with_language_model_provider(mut self, provider_id: LanguageModelProviderId) -> Self {
176        self.provider_id = Some(provider_id);
177        self
178    }
179}
180
181impl Error for AuthRequired {}
182impl fmt::Display for AuthRequired {
183    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
184        write!(f, "Authentication required")
185    }
186}
187
188/// Trait for agents that support listing, selecting, and querying language models.
189///
190/// This is an optional capability; agents indicate support via [AgentConnection::model_selector].
191pub trait AgentModelSelector: 'static {
192    /// Lists all available language models for this agent.
193    ///
194    /// # Parameters
195    /// - `cx`: The GPUI app context for async operations and global access.
196    ///
197    /// # Returns
198    /// A task resolving to the list of models or an error (e.g., if no models are configured).
199    fn list_models(&self, cx: &mut App) -> Task<Result<AgentModelList>>;
200
201    /// Selects a model for a specific session (thread).
202    ///
203    /// This sets the default model for future interactions in the session.
204    /// If the session doesn't exist or the model is invalid, it returns an error.
205    ///
206    /// # Parameters
207    /// - `model`: The model to select (should be one from [list_models]).
208    /// - `cx`: The GPUI app context.
209    ///
210    /// # Returns
211    /// A task resolving to `Ok(())` on success or an error.
212    fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>>;
213
214    /// Retrieves the currently selected model for a specific session (thread).
215    ///
216    /// # Parameters
217    /// - `cx`: The GPUI app context.
218    ///
219    /// # Returns
220    /// A task resolving to the selected model (always set) or an error (e.g., session not found).
221    fn selected_model(&self, cx: &mut App) -> Task<Result<AgentModelInfo>>;
222
223    /// Whenever the model list is updated the receiver will be notified.
224    /// Optional for agents that don't update their model list.
225    fn watch(&self, _cx: &mut App) -> Option<watch::Receiver<()>> {
226        None
227    }
228
229    /// Returns whether the model picker should render a footer.
230    fn should_render_footer(&self) -> bool {
231        false
232    }
233}
234
235/// Icon for a model in the model selector.
236#[derive(Debug, Clone, PartialEq, Eq)]
237pub enum AgentModelIcon {
238    /// A built-in icon from Zed's icon set.
239    Named(IconName),
240    /// Path to a custom SVG icon file.
241    Path(SharedString),
242}
243
244#[derive(Debug, Clone, PartialEq, Eq)]
245pub struct AgentModelInfo {
246    pub id: acp::ModelId,
247    pub name: SharedString,
248    pub description: Option<SharedString>,
249    pub icon: Option<AgentModelIcon>,
250}
251
252impl From<acp::ModelInfo> for AgentModelInfo {
253    fn from(info: acp::ModelInfo) -> Self {
254        Self {
255            id: info.model_id,
256            name: info.name.into(),
257            description: info.description.map(|desc| desc.into()),
258            icon: None,
259        }
260    }
261}
262
263#[derive(Debug, Clone, PartialEq, Eq, Hash)]
264pub struct AgentModelGroupName(pub SharedString);
265
266#[derive(Debug, Clone)]
267pub enum AgentModelList {
268    Flat(Vec<AgentModelInfo>),
269    Grouped(IndexMap<AgentModelGroupName, Vec<AgentModelInfo>>),
270}
271
272impl AgentModelList {
273    pub fn is_empty(&self) -> bool {
274        match self {
275            AgentModelList::Flat(models) => models.is_empty(),
276            AgentModelList::Grouped(groups) => groups.is_empty(),
277        }
278    }
279
280    pub fn is_flat(&self) -> bool {
281        matches!(self, AgentModelList::Flat(_))
282    }
283}
284
285#[cfg(feature = "test-support")]
286mod test_support {
287    use std::sync::Arc;
288
289    use action_log::ActionLog;
290    use collections::HashMap;
291    use futures::{channel::oneshot, future::try_join_all};
292    use gpui::{AppContext as _, WeakEntity};
293    use parking_lot::Mutex;
294
295    use super::*;
296
297    #[derive(Clone, Default)]
298    pub struct StubAgentConnection {
299        sessions: Arc<Mutex<HashMap<acp::SessionId, Session>>>,
300        permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
301        next_prompt_updates: Arc<Mutex<Vec<acp::SessionUpdate>>>,
302    }
303
304    struct Session {
305        thread: WeakEntity<AcpThread>,
306        response_tx: Option<oneshot::Sender<acp::StopReason>>,
307    }
308
309    impl StubAgentConnection {
310        pub fn new() -> Self {
311            Self {
312                next_prompt_updates: Default::default(),
313                permission_requests: HashMap::default(),
314                sessions: Arc::default(),
315            }
316        }
317
318        pub fn set_next_prompt_updates(&self, updates: Vec<acp::SessionUpdate>) {
319            *self.next_prompt_updates.lock() = updates;
320        }
321
322        pub fn with_permission_requests(
323            mut self,
324            permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
325        ) -> Self {
326            self.permission_requests = permission_requests;
327            self
328        }
329
330        pub fn send_update(
331            &self,
332            session_id: acp::SessionId,
333            update: acp::SessionUpdate,
334            cx: &mut App,
335        ) {
336            assert!(
337                self.next_prompt_updates.lock().is_empty(),
338                "Use either send_update or set_next_prompt_updates"
339            );
340
341            self.sessions
342                .lock()
343                .get(&session_id)
344                .unwrap()
345                .thread
346                .update(cx, |thread, cx| {
347                    thread.handle_session_update(update, cx).unwrap();
348                })
349                .unwrap();
350        }
351
352        pub fn end_turn(&self, session_id: acp::SessionId, stop_reason: acp::StopReason) {
353            self.sessions
354                .lock()
355                .get_mut(&session_id)
356                .unwrap()
357                .response_tx
358                .take()
359                .expect("No pending turn")
360                .send(stop_reason)
361                .unwrap();
362        }
363    }
364
365    impl AgentConnection for StubAgentConnection {
366        fn telemetry_id(&self) -> SharedString {
367            "stub".into()
368        }
369
370        fn auth_methods(&self) -> &[acp::AuthMethod] {
371            &[]
372        }
373
374        fn new_thread(
375            self: Rc<Self>,
376            project: Entity<Project>,
377            _cwd: &Path,
378            cx: &mut gpui::App,
379        ) -> Task<gpui::Result<Entity<AcpThread>>> {
380            let session_id = acp::SessionId::new(self.sessions.lock().len().to_string());
381            let action_log = cx.new(|_| ActionLog::new(project.clone()));
382            let thread = cx.new(|cx| {
383                AcpThread::new(
384                    "Test",
385                    self.clone(),
386                    project,
387                    action_log,
388                    session_id.clone(),
389                    watch::Receiver::constant(
390                        acp::PromptCapabilities::new()
391                            .image(true)
392                            .audio(true)
393                            .embedded_context(true),
394                    ),
395                    cx,
396                )
397            });
398            self.sessions.lock().insert(
399                session_id,
400                Session {
401                    thread: thread.downgrade(),
402                    response_tx: None,
403                },
404            );
405            Task::ready(Ok(thread))
406        }
407
408        fn authenticate(
409            &self,
410            _method_id: acp::AuthMethodId,
411            _cx: &mut App,
412        ) -> Task<gpui::Result<()>> {
413            unimplemented!()
414        }
415
416        fn prompt(
417            &self,
418            _id: Option<UserMessageId>,
419            params: acp::PromptRequest,
420            cx: &mut App,
421        ) -> Task<gpui::Result<acp::PromptResponse>> {
422            let mut sessions = self.sessions.lock();
423            let Session {
424                thread,
425                response_tx,
426            } = sessions.get_mut(&params.session_id).unwrap();
427            let mut tasks = vec![];
428            if self.next_prompt_updates.lock().is_empty() {
429                let (tx, rx) = oneshot::channel();
430                response_tx.replace(tx);
431                cx.spawn(async move |_| {
432                    let stop_reason = rx.await?;
433                    Ok(acp::PromptResponse::new(stop_reason))
434                })
435            } else {
436                for update in self.next_prompt_updates.lock().drain(..) {
437                    let thread = thread.clone();
438                    let update = update.clone();
439                    let permission_request = if let acp::SessionUpdate::ToolCall(tool_call) =
440                        &update
441                        && let Some(options) = self.permission_requests.get(&tool_call.tool_call_id)
442                    {
443                        Some((tool_call.clone(), options.clone()))
444                    } else {
445                        None
446                    };
447                    let task = cx.spawn(async move |cx| {
448                        if let Some((tool_call, options)) = permission_request {
449                            thread
450                                .update(cx, |thread, cx| {
451                                    thread.request_tool_call_authorization(
452                                        tool_call.clone().into(),
453                                        options.clone(),
454                                        false,
455                                        cx,
456                                    )
457                                })??
458                                .await;
459                        }
460                        thread.update(cx, |thread, cx| {
461                            thread.handle_session_update(update.clone(), cx).unwrap();
462                        })?;
463                        anyhow::Ok(())
464                    });
465                    tasks.push(task);
466                }
467
468                cx.spawn(async move |_| {
469                    try_join_all(tasks).await?;
470                    Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
471                })
472            }
473        }
474
475        fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
476            if let Some(end_turn_tx) = self
477                .sessions
478                .lock()
479                .get_mut(session_id)
480                .unwrap()
481                .response_tx
482                .take()
483            {
484                end_turn_tx.send(acp::StopReason::Cancelled).unwrap();
485            }
486        }
487
488        fn truncate(
489            &self,
490            _session_id: &agent_client_protocol::SessionId,
491            _cx: &App,
492        ) -> Option<Rc<dyn AgentSessionTruncate>> {
493            Some(Rc::new(StubAgentSessionEditor))
494        }
495
496        fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
497            self
498        }
499    }
500
501    struct StubAgentSessionEditor;
502
503    impl AgentSessionTruncate for StubAgentSessionEditor {
504        fn run(&self, _: UserMessageId, _: &mut App) -> Task<Result<()>> {
505            Task::ready(Ok(()))
506        }
507    }
508}
509
510#[cfg(feature = "test-support")]
511pub use test_support::*;