connection.rs

  1use crate::AcpThread;
  2use agent_client_protocol::{self as acp};
  3use anyhow::Result;
  4use collections::IndexMap;
  5use gpui::{Entity, SharedString, Task};
  6use language_model::LanguageModelProviderId;
  7use project::Project;
  8use serde::{Deserialize, Serialize};
  9use std::{any::Any, error::Error, fmt, path::Path, rc::Rc, sync::Arc};
 10use ui::{App, IconName};
 11use uuid::Uuid;
 12
 13#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
 14pub struct UserMessageId(Arc<str>);
 15
 16impl UserMessageId {
 17    pub fn new() -> Self {
 18        Self(Uuid::new_v4().to_string().into())
 19    }
 20}
 21
 22pub trait AgentConnection {
 23    fn telemetry_id(&self) -> SharedString;
 24
 25    fn new_thread(
 26        self: Rc<Self>,
 27        project: Entity<Project>,
 28        cwd: &Path,
 29        cx: &mut App,
 30    ) -> Task<Result<Entity<AcpThread>>>;
 31
 32    fn auth_methods(&self) -> &[acp::AuthMethod];
 33
 34    fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;
 35
 36    fn prompt(
 37        &self,
 38        user_message_id: Option<UserMessageId>,
 39        params: acp::PromptRequest,
 40        cx: &mut App,
 41    ) -> Task<Result<acp::PromptResponse>>;
 42
 43    fn resume(
 44        &self,
 45        _session_id: &acp::SessionId,
 46        _cx: &App,
 47    ) -> Option<Rc<dyn AgentSessionResume>> {
 48        None
 49    }
 50
 51    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
 52
 53    fn truncate(
 54        &self,
 55        _session_id: &acp::SessionId,
 56        _cx: &App,
 57    ) -> Option<Rc<dyn AgentSessionTruncate>> {
 58        None
 59    }
 60
 61    fn set_title(
 62        &self,
 63        _session_id: &acp::SessionId,
 64        _cx: &App,
 65    ) -> Option<Rc<dyn AgentSessionSetTitle>> {
 66        None
 67    }
 68
 69    /// Returns this agent as an [Rc<dyn ModelSelector>] if the model selection capability is supported.
 70    ///
 71    /// If the agent does not support model selection, returns [None].
 72    /// This allows sharing the selector in UI components.
 73    fn model_selector(&self, _session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
 74        None
 75    }
 76
 77    fn telemetry(&self) -> Option<Rc<dyn AgentTelemetry>> {
 78        None
 79    }
 80
 81    fn session_modes(
 82        &self,
 83        _session_id: &acp::SessionId,
 84        _cx: &App,
 85    ) -> Option<Rc<dyn AgentSessionModes>> {
 86        None
 87    }
 88
 89    fn session_config_options(
 90        &self,
 91        _session_id: &acp::SessionId,
 92        _cx: &App,
 93    ) -> Option<Rc<dyn AgentSessionConfigOptions>> {
 94        None
 95    }
 96
 97    fn into_any(self: Rc<Self>) -> Rc<dyn Any>;
 98}
 99
100impl dyn AgentConnection {
101    pub fn downcast<T: 'static + AgentConnection + Sized>(self: Rc<Self>) -> Option<Rc<T>> {
102        self.into_any().downcast().ok()
103    }
104}
105
106pub trait AgentSessionTruncate {
107    fn run(&self, message_id: UserMessageId, cx: &mut App) -> Task<Result<()>>;
108}
109
110pub trait AgentSessionResume {
111    fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>>;
112}
113
114pub trait AgentSessionSetTitle {
115    fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>>;
116}
117
118pub trait AgentTelemetry {
119    /// A representation of the current thread state that can be serialized for
120    /// storage with telemetry events.
121    fn thread_data(
122        &self,
123        session_id: &acp::SessionId,
124        cx: &mut App,
125    ) -> Task<Result<serde_json::Value>>;
126}
127
128pub trait AgentSessionModes {
129    fn current_mode(&self) -> acp::SessionModeId;
130
131    fn all_modes(&self) -> Vec<acp::SessionMode>;
132
133    fn set_mode(&self, mode: acp::SessionModeId, cx: &mut App) -> Task<Result<()>>;
134}
135
136pub trait AgentSessionConfigOptions {
137    /// Get all current config options with their state
138    fn config_options(&self) -> Vec<acp::SessionConfigOption>;
139
140    /// Set a config option value
141    /// Returns the full updated list of config options
142    fn set_config_option(
143        &self,
144        config_id: acp::SessionConfigId,
145        value: acp::SessionConfigValueId,
146        cx: &mut App,
147    ) -> Task<Result<Vec<acp::SessionConfigOption>>>;
148
149    /// Whenever the config options are updated the receiver will be notified.
150    /// Optional for agents that don't update their config options dynamically.
151    fn watch(&self, _cx: &mut App) -> Option<watch::Receiver<()>> {
152        None
153    }
154}
155
156#[derive(Debug)]
157pub struct AuthRequired {
158    pub description: Option<String>,
159    pub provider_id: Option<LanguageModelProviderId>,
160}
161
162impl AuthRequired {
163    pub fn new() -> Self {
164        Self {
165            description: None,
166            provider_id: None,
167        }
168    }
169
170    pub fn with_description(mut self, description: String) -> Self {
171        self.description = Some(description);
172        self
173    }
174
175    pub fn with_language_model_provider(mut self, provider_id: LanguageModelProviderId) -> Self {
176        self.provider_id = Some(provider_id);
177        self
178    }
179}
180
181impl Error for AuthRequired {}
182impl fmt::Display for AuthRequired {
183    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
184        write!(f, "Authentication required")
185    }
186}
187
188/// Trait for agents that support listing, selecting, and querying language models.
189///
190/// This is an optional capability; agents indicate support via [AgentConnection::model_selector].
191pub trait AgentModelSelector: 'static {
192    /// Lists all available language models for this agent.
193    ///
194    /// # Parameters
195    /// - `cx`: The GPUI app context for async operations and global access.
196    ///
197    /// # Returns
198    /// A task resolving to the list of models or an error (e.g., if no models are configured).
199    fn list_models(&self, cx: &mut App) -> Task<Result<AgentModelList>>;
200
201    /// Selects a model for a specific session (thread).
202    ///
203    /// This sets the default model for future interactions in the session.
204    /// If the session doesn't exist or the model is invalid, it returns an error.
205    ///
206    /// # Parameters
207    /// - `model`: The model to select (should be one from [list_models]).
208    /// - `cx`: The GPUI app context.
209    ///
210    /// # Returns
211    /// A task resolving to `Ok(())` on success or an error.
212    fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>>;
213
214    /// Retrieves the currently selected model for a specific session (thread).
215    ///
216    /// # Parameters
217    /// - `cx`: The GPUI app context.
218    ///
219    /// # Returns
220    /// A task resolving to the selected model (always set) or an error (e.g., session not found).
221    fn selected_model(&self, cx: &mut App) -> Task<Result<AgentModelInfo>>;
222
223    /// Whenever the model list is updated the receiver will be notified.
224    /// Optional for agents that don't update their model list.
225    fn watch(&self, _cx: &mut App) -> Option<watch::Receiver<()>> {
226        None
227    }
228
229    /// Returns whether the model picker should render a footer.
230    fn should_render_footer(&self) -> bool {
231        false
232    }
233}
234
235/// Icon for a model in the model selector.
236#[derive(Debug, Clone, PartialEq, Eq)]
237pub enum AgentModelIcon {
238    /// A built-in icon from Zed's icon set.
239    Named(IconName),
240    /// Path to a custom SVG icon file.
241    Path(SharedString),
242}
243
244#[derive(Debug, Clone, PartialEq, Eq)]
245pub struct AgentModelInfo {
246    pub id: acp::ModelId,
247    pub name: SharedString,
248    pub description: Option<SharedString>,
249    pub icon: Option<AgentModelIcon>,
250}
251
252impl From<acp::ModelInfo> for AgentModelInfo {
253    fn from(info: acp::ModelInfo) -> Self {
254        Self {
255            id: info.model_id,
256            name: info.name.into(),
257            description: info.description.map(|desc| desc.into()),
258            icon: None,
259        }
260    }
261}
262
263#[derive(Debug, Clone, PartialEq, Eq, Hash)]
264pub struct AgentModelGroupName(pub SharedString);
265
266#[derive(Debug, Clone)]
267pub enum AgentModelList {
268    Flat(Vec<AgentModelInfo>),
269    Grouped(IndexMap<AgentModelGroupName, Vec<AgentModelInfo>>),
270}
271
272impl AgentModelList {
273    pub fn is_empty(&self) -> bool {
274        match self {
275            AgentModelList::Flat(models) => models.is_empty(),
276            AgentModelList::Grouped(groups) => groups.is_empty(),
277        }
278    }
279
280    pub fn is_flat(&self) -> bool {
281        matches!(self, AgentModelList::Flat(_))
282    }
283}
284
#[cfg(feature = "test-support")]
mod test_support {
    //! Test-only stubs and helpers for acp_thread.
    //!
    //! This module is gated by the `test-support` feature and is not included
    //! in production builds. It provides:
    //! - `StubAgentConnection` for mocking agent connections in tests
    //! - `create_test_png_base64` for generating test images

    use std::sync::Arc;

    use action_log::ActionLog;
    use collections::HashMap;
    use futures::{channel::oneshot, future::try_join_all};
    use gpui::{AppContext as _, WeakEntity};
    use parking_lot::Mutex;

    use super::*;

    /// Creates a PNG image encoded as base64 for testing.
    ///
    /// Generates a `width` x `height` PNG in which every pixel is the RGBA
    /// `color`, and returns it as a base64 string (standard alphabet)
    /// suitable for use in `ImageContent`.
    ///
    /// # Panics
    /// Panics if the pixel buffer size overflows `usize` or if PNG encoding fails.
    pub fn create_test_png_base64(width: u32, height: u32, color: [u8; 4]) -> String {
        use image::ImageEncoder as _;

        // Size the buffer with `usize` arithmetic so large dimensions cannot
        // overflow `u32`.
        let pixel_count = (width as usize)
            .checked_mul(height as usize)
            .expect("test image dimensions overflow usize");
        // One copy of the 4-byte RGBA color per pixel.
        let pixels = color.repeat(pixel_count);

        let mut png_data = Vec::new();
        image::codecs::png::PngEncoder::new(&mut png_data)
            .write_image(&pixels, width, height, image::ExtendedColorType::Rgba8)
            .expect("Failed to encode PNG");

        base64::Engine::encode(&base64::engine::general_purpose::STANDARD, &png_data)
    }

    /// In-memory [`AgentConnection`] for tests.
    ///
    /// Clones share sessions, queued prompt updates and the stub model
    /// selector (all held behind `Arc`s); `permission_requests` is copied.
    #[derive(Clone, Default)]
    pub struct StubAgentConnection {
        sessions: Arc<Mutex<HashMap<acp::SessionId, Session>>>,
        permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
        next_prompt_updates: Arc<Mutex<Vec<acp::SessionUpdate>>>,
        // Kept on the connection so a model chosen via `select_model` is
        // still selected the next time `model_selector` is called.
        stub_model_selector: StubModelSelector,
    }

    /// Per-session state tracked by [`StubAgentConnection`].
    struct Session {
        /// The thread created for this session.
        thread: WeakEntity<AcpThread>,
        /// Set while a prompt turn is waiting to be finished by
        /// [`StubAgentConnection::end_turn`] or a cancellation.
        response_tx: Option<oneshot::Sender<acp::StopReason>>,
    }

    impl StubAgentConnection {
        /// Creates a connection with no sessions, no queued updates and no
        /// permission requests.
        pub fn new() -> Self {
            Self::default()
        }

        /// Queues `updates` for the next `prompt` call, replacing anything
        /// already queued.
        ///
        /// When the queue is non-empty, `prompt` applies the updates and ends
        /// the turn with `StopReason::EndTurn`.
        pub fn set_next_prompt_updates(&self, updates: Vec<acp::SessionUpdate>) {
            *self.next_prompt_updates.lock() = updates;
        }

        /// Registers tool calls that must be authorized before their update is
        /// applied: when a queued `ToolCall` update's id is a key of
        /// `permission_requests`, `prompt` first requests authorization with
        /// the associated options.
        pub fn with_permission_requests(
            mut self,
            permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
        ) -> Self {
            self.permission_requests = permission_requests;
            self
        }

        /// Applies `update` to the thread of `session_id` immediately.
        ///
        /// # Panics
        /// Panics if prompt updates are queued via `set_next_prompt_updates`,
        /// if the session is unknown, if its thread has been dropped, or if
        /// the update fails to apply.
        pub fn send_update(
            &self,
            session_id: acp::SessionId,
            update: acp::SessionUpdate,
            cx: &mut App,
        ) {
            assert!(
                self.next_prompt_updates.lock().is_empty(),
                "Use either send_update or set_next_prompt_updates"
            );

            self.sessions
                .lock()
                .get(&session_id)
                .unwrap()
                .thread
                .update(cx, |thread, cx| {
                    thread.handle_session_update(update, cx).unwrap();
                })
                .unwrap();
        }

        /// Finishes the pending prompt turn of `session_id` with `stop_reason`.
        ///
        /// # Panics
        /// Panics if the session is unknown, if no turn is pending, or if the
        /// pending prompt task has already been dropped.
        pub fn end_turn(&self, session_id: acp::SessionId, stop_reason: acp::StopReason) {
            self.sessions
                .lock()
                .get_mut(&session_id)
                .unwrap()
                .response_tx
                .take()
                .expect("No pending turn")
                .send(stop_reason)
                .unwrap();
        }

        /// Returns a model selector for this stub connection.
        ///
        /// All selectors returned by one connection (and its clones) share
        /// state, so a selection made through one is visible through the others.
        pub fn model_selector_impl(&self) -> Rc<dyn AgentModelSelector> {
            Rc::new(self.stub_model_selector.clone())
        }
    }

    impl AgentConnection for StubAgentConnection {
        fn telemetry_id(&self) -> SharedString {
            "stub".into()
        }

        fn auth_methods(&self) -> &[acp::AuthMethod] {
            &[]
        }

        fn model_selector(
            &self,
            _session_id: &acp::SessionId,
        ) -> Option<Rc<dyn AgentModelSelector>> {
            Some(self.model_selector_impl())
        }

        /// Creates a thread titled "Test". Session ids are the number of
        /// sessions existing at creation time ("0", "1", ...).
        fn new_thread(
            self: Rc<Self>,
            project: Entity<Project>,
            _cwd: &Path,
            cx: &mut gpui::App,
        ) -> Task<gpui::Result<Entity<AcpThread>>> {
            let session_id = acp::SessionId::new(self.sessions.lock().len().to_string());
            let action_log = cx.new(|_| ActionLog::new(project.clone()));
            let thread = cx.new(|cx| {
                AcpThread::new(
                    "Test",
                    self.clone(),
                    project,
                    action_log,
                    session_id.clone(),
                    watch::Receiver::constant(
                        acp::PromptCapabilities::new()
                            .image(true)
                            .audio(true)
                            .embedded_context(true),
                    ),
                    cx,
                )
            });
            self.sessions.lock().insert(
                session_id,
                Session {
                    thread: thread.downgrade(),
                    response_tx: None,
                },
            );
            Task::ready(Ok(thread))
        }

        fn authenticate(
            &self,
            _method_id: acp::AuthMethodId,
            _cx: &mut App,
        ) -> Task<gpui::Result<()>> {
            unimplemented!()
        }

        /// With queued updates, applies them (after any configured permission
        /// requests) and ends the turn with `EndTurn`; otherwise the turn stays
        /// pending until `end_turn` or `cancel` is called.
        fn prompt(
            &self,
            _id: Option<UserMessageId>,
            params: acp::PromptRequest,
            cx: &mut App,
        ) -> Task<gpui::Result<acp::PromptResponse>> {
            let mut sessions = self.sessions.lock();
            let Session {
                thread,
                response_tx,
            } = sessions.get_mut(&params.session_id).unwrap();

            // Take the whole queue under a single lock.
            let updates = std::mem::take(&mut *self.next_prompt_updates.lock());
            if updates.is_empty() {
                let (tx, rx) = oneshot::channel();
                response_tx.replace(tx);
                return cx.spawn(async move |_| {
                    let stop_reason = rx.await?;
                    Ok(acp::PromptResponse::new(stop_reason))
                });
            }

            let mut tasks = Vec::with_capacity(updates.len());
            for update in updates {
                let thread = thread.clone();
                let permission_request = match &update {
                    acp::SessionUpdate::ToolCall(tool_call) => self
                        .permission_requests
                        .get(&tool_call.tool_call_id)
                        .map(|options| (tool_call.clone(), options.clone())),
                    _ => None,
                };
                let task = cx.spawn(async move |cx| {
                    if let Some((tool_call, options)) = permission_request {
                        thread
                            .update(cx, |thread, cx| {
                                thread.request_tool_call_authorization(
                                    tool_call.into(),
                                    options,
                                    false,
                                    cx,
                                )
                            })??
                            .await;
                    }
                    thread.update(cx, |thread, cx| {
                        thread.handle_session_update(update, cx).unwrap();
                    })?;
                    anyhow::Ok(())
                });
                tasks.push(task);
            }

            cx.spawn(async move |_| {
                try_join_all(tasks).await?;
                Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
            })
        }

        /// Finishes a pending turn (if any) with `StopReason::Cancelled`.
        /// Panics if the session is unknown.
        fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
            if let Some(end_turn_tx) = self
                .sessions
                .lock()
                .get_mut(session_id)
                .unwrap()
                .response_tx
                .take()
            {
                end_turn_tx.send(acp::StopReason::Cancelled).unwrap();
            }
        }

        fn truncate(
            &self,
            _session_id: &agent_client_protocol::SessionId,
            _cx: &App,
        ) -> Option<Rc<dyn AgentSessionTruncate>> {
            Some(Rc::new(StubAgentSessionEditor))
        }

        fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
            self
        }
    }

    /// Truncation handle that changes nothing and always succeeds.
    struct StubAgentSessionEditor;

    impl AgentSessionTruncate for StubAgentSessionEditor {
        fn run(&self, _: UserMessageId, _: &mut App) -> Task<Result<()>> {
            Task::ready(Ok(()))
        }
    }

    /// Model selector offering exactly one model.
    ///
    /// `select_model` only overwrites that model's id. Clones share state
    /// through the inner `Arc`.
    #[derive(Clone)]
    struct StubModelSelector {
        selected_model: Arc<Mutex<AgentModelInfo>>,
    }

    impl StubModelSelector {
        fn new() -> Self {
            Self {
                selected_model: Arc::new(Mutex::new(AgentModelInfo {
                    id: acp::ModelId::new("visual-test-model"),
                    name: "Visual Test Model".into(),
                    description: Some("A stub model for visual testing".into()),
                    icon: Some(AgentModelIcon::Named(ui::IconName::ZedAssistant)),
                })),
            }
        }
    }

    impl Default for StubModelSelector {
        fn default() -> Self {
            Self::new()
        }
    }

    impl AgentModelSelector for StubModelSelector {
        fn list_models(&self, _cx: &mut App) -> Task<Result<AgentModelList>> {
            let model = self.selected_model.lock().clone();
            Task::ready(Ok(AgentModelList::Flat(vec![model])))
        }

        fn select_model(&self, model_id: acp::ModelId, _cx: &mut App) -> Task<Result<()>> {
            self.selected_model.lock().id = model_id;
            Task::ready(Ok(()))
        }

        fn selected_model(&self, _cx: &mut App) -> Task<Result<AgentModelInfo>> {
            Task::ready(Ok(self.selected_model.lock().clone()))
        }
    }
}
590
591#[cfg(feature = "test-support")]
592pub use test_support::*;