connection.rs

  1use crate::AcpThread;
  2use agent_client_protocol::{self as acp};
  3use anyhow::Result;
  4use collections::IndexMap;
  5use gpui::{Entity, SharedString, Task};
  6use language_model::LanguageModelProviderId;
  7use project::Project;
  8use serde::{Deserialize, Serialize};
  9use std::{any::Any, error::Error, fmt, path::Path, rc::Rc, sync::Arc};
 10use ui::{App, IconName};
 11use uuid::Uuid;
 12
 13#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
 14pub struct UserMessageId(Arc<str>);
 15
 16impl UserMessageId {
 17    pub fn new() -> Self {
 18        Self(Uuid::new_v4().to_string().into())
 19    }
 20}
 21
 22pub trait AgentConnection {
 23    fn telemetry_id(&self) -> SharedString;
 24
 25    fn new_thread(
 26        self: Rc<Self>,
 27        project: Entity<Project>,
 28        cwd: &Path,
 29        cx: &mut App,
 30    ) -> Task<Result<Entity<AcpThread>>>;
 31
 32    fn auth_methods(&self) -> &[acp::AuthMethod];
 33
 34    fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;
 35
 36    fn prompt(
 37        &self,
 38        user_message_id: Option<UserMessageId>,
 39        params: acp::PromptRequest,
 40        cx: &mut App,
 41    ) -> Task<Result<acp::PromptResponse>>;
 42
 43    fn resume(
 44        &self,
 45        _session_id: &acp::SessionId,
 46        _cx: &App,
 47    ) -> Option<Rc<dyn AgentSessionResume>> {
 48        None
 49    }
 50
 51    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
 52
 53    fn truncate(
 54        &self,
 55        _session_id: &acp::SessionId,
 56        _cx: &App,
 57    ) -> Option<Rc<dyn AgentSessionTruncate>> {
 58        None
 59    }
 60
 61    fn set_title(
 62        &self,
 63        _session_id: &acp::SessionId,
 64        _cx: &App,
 65    ) -> Option<Rc<dyn AgentSessionSetTitle>> {
 66        None
 67    }
 68
 69    /// Returns this agent as an [Rc<dyn ModelSelector>] if the model selection capability is supported.
 70    ///
 71    /// If the agent does not support model selection, returns [None].
 72    /// This allows sharing the selector in UI components.
 73    fn model_selector(&self, _session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
 74        None
 75    }
 76
 77    fn telemetry(&self) -> Option<Rc<dyn AgentTelemetry>> {
 78        None
 79    }
 80
 81    fn session_modes(
 82        &self,
 83        _session_id: &acp::SessionId,
 84        _cx: &App,
 85    ) -> Option<Rc<dyn AgentSessionModes>> {
 86        None
 87    }
 88
 89    fn into_any(self: Rc<Self>) -> Rc<dyn Any>;
 90}
 91
 92impl dyn AgentConnection {
 93    pub fn downcast<T: 'static + AgentConnection + Sized>(self: Rc<Self>) -> Option<Rc<T>> {
 94        self.into_any().downcast().ok()
 95    }
 96}
 97
 98pub trait AgentSessionTruncate {
 99    fn run(&self, message_id: UserMessageId, cx: &mut App) -> Task<Result<()>>;
100}
101
102pub trait AgentSessionResume {
103    fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>>;
104}
105
106pub trait AgentSessionSetTitle {
107    fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>>;
108}
109
110pub trait AgentTelemetry {
111    /// A representation of the current thread state that can be serialized for
112    /// storage with telemetry events.
113    fn thread_data(
114        &self,
115        session_id: &acp::SessionId,
116        cx: &mut App,
117    ) -> Task<Result<serde_json::Value>>;
118}
119
120pub trait AgentSessionModes {
121    fn current_mode(&self) -> acp::SessionModeId;
122
123    fn all_modes(&self) -> Vec<acp::SessionMode>;
124
125    fn set_mode(&self, mode: acp::SessionModeId, cx: &mut App) -> Task<Result<()>>;
126}
127
128#[derive(Debug)]
129pub struct AuthRequired {
130    pub description: Option<String>,
131    pub provider_id: Option<LanguageModelProviderId>,
132}
133
134impl AuthRequired {
135    pub fn new() -> Self {
136        Self {
137            description: None,
138            provider_id: None,
139        }
140    }
141
142    pub fn with_description(mut self, description: String) -> Self {
143        self.description = Some(description);
144        self
145    }
146
147    pub fn with_language_model_provider(mut self, provider_id: LanguageModelProviderId) -> Self {
148        self.provider_id = Some(provider_id);
149        self
150    }
151}
152
153impl Error for AuthRequired {}
154impl fmt::Display for AuthRequired {
155    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
156        write!(f, "Authentication required")
157    }
158}
159
160/// Trait for agents that support listing, selecting, and querying language models.
161///
162/// This is an optional capability; agents indicate support via [AgentConnection::model_selector].
163pub trait AgentModelSelector: 'static {
164    /// Lists all available language models for this agent.
165    ///
166    /// # Parameters
167    /// - `cx`: The GPUI app context for async operations and global access.
168    ///
169    /// # Returns
170    /// A task resolving to the list of models or an error (e.g., if no models are configured).
171    fn list_models(&self, cx: &mut App) -> Task<Result<AgentModelList>>;
172
173    /// Selects a model for a specific session (thread).
174    ///
175    /// This sets the default model for future interactions in the session.
176    /// If the session doesn't exist or the model is invalid, it returns an error.
177    ///
178    /// # Parameters
179    /// - `model`: The model to select (should be one from [list_models]).
180    /// - `cx`: The GPUI app context.
181    ///
182    /// # Returns
183    /// A task resolving to `Ok(())` on success or an error.
184    fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>>;
185
186    /// Retrieves the currently selected model for a specific session (thread).
187    ///
188    /// # Parameters
189    /// - `cx`: The GPUI app context.
190    ///
191    /// # Returns
192    /// A task resolving to the selected model (always set) or an error (e.g., session not found).
193    fn selected_model(&self, cx: &mut App) -> Task<Result<AgentModelInfo>>;
194
195    /// Whenever the model list is updated the receiver will be notified.
196    /// Optional for agents that don't update their model list.
197    fn watch(&self, _cx: &mut App) -> Option<watch::Receiver<()>> {
198        None
199    }
200
201    /// Returns whether the model picker should render a footer.
202    fn should_render_footer(&self) -> bool {
203        false
204    }
205
206    /// Whether this selector supports the favorites feature.
207    /// Only the native agent uses the model ID format that maps to settings.
208    fn supports_favorites(&self) -> bool {
209        false
210    }
211}
212
213/// Icon for a model in the model selector.
214#[derive(Debug, Clone, PartialEq, Eq)]
215pub enum AgentModelIcon {
216    /// A built-in icon from Zed's icon set.
217    Named(IconName),
218    /// Path to a custom SVG icon file.
219    Path(SharedString),
220}
221
222#[derive(Debug, Clone, PartialEq, Eq)]
223pub struct AgentModelInfo {
224    pub id: acp::ModelId,
225    pub name: SharedString,
226    pub description: Option<SharedString>,
227    pub icon: Option<AgentModelIcon>,
228}
229
230impl From<acp::ModelInfo> for AgentModelInfo {
231    fn from(info: acp::ModelInfo) -> Self {
232        Self {
233            id: info.model_id,
234            name: info.name.into(),
235            description: info.description.map(|desc| desc.into()),
236            icon: None,
237        }
238    }
239}
240
241#[derive(Debug, Clone, PartialEq, Eq, Hash)]
242pub struct AgentModelGroupName(pub SharedString);
243
244#[derive(Debug, Clone)]
245pub enum AgentModelList {
246    Flat(Vec<AgentModelInfo>),
247    Grouped(IndexMap<AgentModelGroupName, Vec<AgentModelInfo>>),
248}
249
250impl AgentModelList {
251    pub fn is_empty(&self) -> bool {
252        match self {
253            AgentModelList::Flat(models) => models.is_empty(),
254            AgentModelList::Grouped(groups) => groups.is_empty(),
255        }
256    }
257
258    pub fn is_flat(&self) -> bool {
259        matches!(self, AgentModelList::Flat(_))
260    }
261}
262
/// Test-only [`AgentConnection`] implementation, compiled with the
/// `test-support` feature.
#[cfg(feature = "test-support")]
mod test_support {
    use std::sync::Arc;

    use action_log::ActionLog;
    use collections::HashMap;
    use futures::{channel::oneshot, future::try_join_all};
    use gpui::{AppContext as _, WeakEntity};
    use parking_lot::Mutex;

    use super::*;

    /// Scriptable in-memory [`AgentConnection`] for tests.
    ///
    /// A `prompt` either replays the batch queued with
    /// [`StubAgentConnection::set_next_prompt_updates`] and then ends the turn
    /// with [`acp::StopReason::EndTurn`], or, when nothing is queued, stays
    /// pending until the test calls [`StubAgentConnection::end_turn`] or the turn
    /// is cancelled.
    #[derive(Clone, Default)]
    pub struct StubAgentConnection {
        /// Live sessions keyed by id, shared by all clones of this connection.
        sessions: Arc<Mutex<HashMap<acp::SessionId, Session>>>,
        /// Permission options to request before applying a queued tool call
        /// update whose id matches.
        permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
        /// Updates the next `prompt` replays; empty means "wait for `end_turn`".
        next_prompt_updates: Arc<Mutex<Vec<acp::SessionUpdate>>>,
    }

    /// Per-session state tracked by the stub.
    struct Session {
        thread: WeakEntity<AcpThread>,
        /// Completes the turn left pending by `prompt`, if any.
        response_tx: Option<oneshot::Sender<acp::StopReason>>,
    }

    impl StubAgentConnection {
        /// Creates a stub with no sessions, no permission requests and no queued updates.
        pub fn new() -> Self {
            Self::default()
        }

        /// Queues `updates` to be replayed by the next `prompt` call.
        pub fn set_next_prompt_updates(&self, updates: Vec<acp::SessionUpdate>) {
            *self.next_prompt_updates.lock() = updates;
        }

        /// Configures permission options to request before applying a queued
        /// `ToolCall` update with a matching tool call id.
        pub fn with_permission_requests(
            mut self,
            permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
        ) -> Self {
            self.permission_requests = permission_requests;
            self
        }

        /// Applies `update` directly to the thread of `session_id`.
        ///
        /// # Panics
        /// If updates are queued via `set_next_prompt_updates`, if the session is
        /// unknown, or if the thread has been released or rejects the update.
        pub fn send_update(
            &self,
            session_id: acp::SessionId,
            update: acp::SessionUpdate,
            cx: &mut App,
        ) {
            assert!(
                self.next_prompt_updates.lock().is_empty(),
                "Use either send_update or set_next_prompt_updates"
            );

            // Clone the handle and release the (non-reentrant) lock before updating
            // the thread, so anything reacting to the update can call back into
            // this connection without deadlocking on `sessions`.
            let thread = self.sessions.lock().get(&session_id).unwrap().thread.clone();
            thread
                .update(cx, |thread, cx| {
                    thread.handle_session_update(update, cx).unwrap();
                })
                .unwrap();
        }

        /// Ends the pending turn of `session_id` with `stop_reason`.
        ///
        /// # Panics
        /// If the session is unknown, no turn is pending, or the prompt's
        /// receiver has been dropped.
        pub fn end_turn(&self, session_id: acp::SessionId, stop_reason: acp::StopReason) {
            self.sessions
                .lock()
                .get_mut(&session_id)
                .unwrap()
                .response_tx
                .take()
                .expect("No pending turn")
                .send(stop_reason)
                .unwrap();
        }
    }

    impl AgentConnection for StubAgentConnection {
        fn telemetry_id(&self) -> SharedString {
            "stub".into()
        }

        fn auth_methods(&self) -> &[acp::AuthMethod] {
            &[]
        }

        fn new_thread(
            self: Rc<Self>,
            project: Entity<Project>,
            _cwd: &Path,
            cx: &mut App,
        ) -> Task<Result<Entity<AcpThread>>> {
            // Session ids are sequential; sessions are never removed, so they stay unique.
            let session_id = acp::SessionId::new(self.sessions.lock().len().to_string());
            let action_log = cx.new(|_| ActionLog::new(project.clone()));
            let thread = cx.new(|cx| {
                AcpThread::new(
                    "Test",
                    self.clone(),
                    project,
                    action_log,
                    session_id.clone(),
                    watch::Receiver::constant(
                        acp::PromptCapabilities::new()
                            .image(true)
                            .audio(true)
                            .embedded_context(true),
                    ),
                    cx,
                )
            });
            self.sessions.lock().insert(
                session_id,
                Session {
                    thread: thread.downgrade(),
                    response_tx: None,
                },
            );
            Task::ready(Ok(thread))
        }

        fn authenticate(&self, _method_id: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
            unimplemented!()
        }

        fn prompt(
            &self,
            _id: Option<UserMessageId>,
            params: acp::PromptRequest,
            cx: &mut App,
        ) -> Task<Result<acp::PromptResponse>> {
            let mut sessions = self.sessions.lock();
            let session = sessions.get_mut(&params.session_id).unwrap();
            // Take the queued batch in one step so the emptiness check and the
            // replay observe the same updates.
            let updates = std::mem::take(&mut *self.next_prompt_updates.lock());

            if updates.is_empty() {
                // Nothing scripted: keep the turn open until `end_turn` or `cancel`.
                let (tx, rx) = oneshot::channel();
                session.response_tx.replace(tx);
                return cx.spawn(async move |_| {
                    let stop_reason = rx.await?;
                    Ok(acp::PromptResponse::new(stop_reason))
                });
            }

            let thread = session.thread.clone();
            drop(sessions);

            let mut tasks = Vec::with_capacity(updates.len());
            for update in updates {
                let thread = thread.clone();
                let permission_request = match &update {
                    acp::SessionUpdate::ToolCall(tool_call) => self
                        .permission_requests
                        .get(&tool_call.tool_call_id)
                        .map(|options| (tool_call.clone(), options.clone())),
                    _ => None,
                };
                tasks.push(cx.spawn(async move |cx| {
                    if let Some((tool_call, options)) = permission_request {
                        // Wait for the authorization to be resolved before applying
                        // the tool call update itself.
                        thread
                            .update(cx, |thread, cx| {
                                thread.request_tool_call_authorization(
                                    tool_call.into(),
                                    options,
                                    false,
                                    cx,
                                )
                            })??
                            .await;
                    }
                    thread.update(cx, |thread, cx| {
                        thread.handle_session_update(update, cx).unwrap();
                    })?;
                    anyhow::Ok(())
                }));
            }

            cx.spawn(async move |_| {
                try_join_all(tasks).await?;
                Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
            })
        }

        fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
            let pending = self.sessions.lock().get_mut(session_id).unwrap().response_tx.take();
            // Cancelling when no turn is pending is a no-op.
            if let Some(end_turn_tx) = pending {
                end_turn_tx.send(acp::StopReason::Cancelled).unwrap();
            }
        }

        fn truncate(
            &self,
            _session_id: &acp::SessionId,
            _cx: &App,
        ) -> Option<Rc<dyn AgentSessionTruncate>> {
            Some(Rc::new(StubAgentSessionEditor))
        }

        fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
            self
        }
    }

    /// Truncation handle whose `run` always succeeds without doing anything.
    struct StubAgentSessionEditor;

    impl AgentSessionTruncate for StubAgentSessionEditor {
        fn run(&self, _: UserMessageId, _: &mut App) -> Task<Result<()>> {
            Task::ready(Ok(()))
        }
    }
}
487
488#[cfg(feature = "test-support")]
489pub use test_support::*;