// connection.rs

  1use crate::AcpThread;
  2use agent_client_protocol::{self as acp};
  3use anyhow::Result;
  4use chrono::{DateTime, Utc};
  5use collections::IndexMap;
  6use gpui::{Entity, SharedString, Task};
  7use language_model::LanguageModelProviderId;
  8use project::Project;
  9use serde::{Deserialize, Serialize};
 10use std::{
 11    any::Any,
 12    error::Error,
 13    fmt,
 14    path::{Path, PathBuf},
 15    rc::Rc,
 16    sync::Arc,
 17};
 18use ui::{App, IconName};
 19use uuid::Uuid;
 20
 21#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
 22pub struct UserMessageId(Arc<str>);
 23
 24impl UserMessageId {
 25    pub fn new() -> Self {
 26        Self(Uuid::new_v4().to_string().into())
 27    }
 28}
 29
 30pub trait AgentConnection {
 31    fn telemetry_id(&self) -> SharedString;
 32
 33    fn new_thread(
 34        self: Rc<Self>,
 35        project: Entity<Project>,
 36        cwd: &Path,
 37        cx: &mut App,
 38    ) -> Task<Result<Entity<AcpThread>>>;
 39
 40    fn auth_methods(&self) -> &[acp::AuthMethod];
 41
 42    fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;
 43
 44    fn prompt(
 45        &self,
 46        user_message_id: Option<UserMessageId>,
 47        params: acp::PromptRequest,
 48        cx: &mut App,
 49    ) -> Task<Result<acp::PromptResponse>>;
 50
 51    fn resume(
 52        &self,
 53        _session_id: &acp::SessionId,
 54        _cx: &App,
 55    ) -> Option<Rc<dyn AgentSessionResume>> {
 56        None
 57    }
 58
 59    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
 60
 61    fn truncate(
 62        &self,
 63        _session_id: &acp::SessionId,
 64        _cx: &App,
 65    ) -> Option<Rc<dyn AgentSessionTruncate>> {
 66        None
 67    }
 68
 69    fn set_title(
 70        &self,
 71        _session_id: &acp::SessionId,
 72        _cx: &App,
 73    ) -> Option<Rc<dyn AgentSessionSetTitle>> {
 74        None
 75    }
 76
 77    /// Returns this agent as an [Rc<dyn ModelSelector>] if the model selection capability is supported.
 78    ///
 79    /// If the agent does not support model selection, returns [None].
 80    /// This allows sharing the selector in UI components.
 81    fn model_selector(&self, _session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
 82        None
 83    }
 84
 85    fn telemetry(&self) -> Option<Rc<dyn AgentTelemetry>> {
 86        None
 87    }
 88
 89    fn session_modes(
 90        &self,
 91        _session_id: &acp::SessionId,
 92        _cx: &App,
 93    ) -> Option<Rc<dyn AgentSessionModes>> {
 94        None
 95    }
 96
 97    fn session_config_options(
 98        &self,
 99        _session_id: &acp::SessionId,
100        _cx: &App,
101    ) -> Option<Rc<dyn AgentSessionConfigOptions>> {
102        None
103    }
104
105    fn session_list(&self, _cx: &mut App) -> Option<Rc<dyn AgentSessionList>> {
106        None
107    }
108
109    fn into_any(self: Rc<Self>) -> Rc<dyn Any>;
110}
111
112impl dyn AgentConnection {
113    pub fn downcast<T: 'static + AgentConnection + Sized>(self: Rc<Self>) -> Option<Rc<T>> {
114        self.into_any().downcast().ok()
115    }
116}
117
118pub trait AgentSessionTruncate {
119    fn run(&self, message_id: UserMessageId, cx: &mut App) -> Task<Result<()>>;
120}
121
122pub trait AgentSessionResume {
123    fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>>;
124}
125
126pub trait AgentSessionSetTitle {
127    fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>>;
128}
129
130pub trait AgentTelemetry {
131    /// A representation of the current thread state that can be serialized for
132    /// storage with telemetry events.
133    fn thread_data(
134        &self,
135        session_id: &acp::SessionId,
136        cx: &mut App,
137    ) -> Task<Result<serde_json::Value>>;
138}
139
140pub trait AgentSessionModes {
141    fn current_mode(&self) -> acp::SessionModeId;
142
143    fn all_modes(&self) -> Vec<acp::SessionMode>;
144
145    fn set_mode(&self, mode: acp::SessionModeId, cx: &mut App) -> Task<Result<()>>;
146}
147
148pub trait AgentSessionConfigOptions {
149    /// Get all current config options with their state
150    fn config_options(&self) -> Vec<acp::SessionConfigOption>;
151
152    /// Set a config option value
153    /// Returns the full updated list of config options
154    fn set_config_option(
155        &self,
156        config_id: acp::SessionConfigId,
157        value: acp::SessionConfigValueId,
158        cx: &mut App,
159    ) -> Task<Result<Vec<acp::SessionConfigOption>>>;
160
161    /// Whenever the config options are updated the receiver will be notified.
162    /// Optional for agents that don't update their config options dynamically.
163    fn watch(&self, _cx: &mut App) -> Option<watch::Receiver<()>> {
164        None
165    }
166}
167
168#[derive(Debug, Clone, Default)]
169pub struct AgentSessionListRequest {
170    pub cwd: Option<PathBuf>,
171    pub cursor: Option<String>,
172    pub meta: Option<acp::Meta>,
173}
174
175#[derive(Debug, Clone)]
176pub struct AgentSessionListResponse {
177    pub sessions: Vec<AgentSessionInfo>,
178    pub next_cursor: Option<String>,
179    pub meta: Option<acp::Meta>,
180}
181
182impl AgentSessionListResponse {
183    pub fn new(sessions: Vec<AgentSessionInfo>) -> Self {
184        Self {
185            sessions,
186            next_cursor: None,
187            meta: None,
188        }
189    }
190}
191
192#[derive(Debug, Clone, PartialEq)]
193pub struct AgentSessionInfo {
194    pub session_id: acp::SessionId,
195    pub cwd: Option<PathBuf>,
196    pub title: Option<SharedString>,
197    pub updated_at: Option<DateTime<Utc>>,
198    pub meta: Option<acp::Meta>,
199}
200
201impl AgentSessionInfo {
202    pub fn new(session_id: impl Into<acp::SessionId>) -> Self {
203        Self {
204            session_id: session_id.into(),
205            cwd: None,
206            title: None,
207            updated_at: None,
208            meta: None,
209        }
210    }
211}
212
213pub trait AgentSessionList {
214    fn list_sessions(
215        &self,
216        request: AgentSessionListRequest,
217        cx: &mut App,
218    ) -> Task<Result<AgentSessionListResponse>>;
219
220    fn supports_delete(&self) -> bool {
221        false
222    }
223
224    fn delete_session(&self, _session_id: &acp::SessionId, _cx: &mut App) -> Task<Result<()>> {
225        Task::ready(Err(anyhow::anyhow!("delete_session not supported")))
226    }
227
228    fn delete_sessions(&self, _cx: &mut App) -> Task<Result<()>> {
229        Task::ready(Err(anyhow::anyhow!("delete_sessions not supported")))
230    }
231
232    fn watch(&self, _cx: &mut App) -> Option<watch::Receiver<()>> {
233        None
234    }
235
236    fn into_any(self: Rc<Self>) -> Rc<dyn Any>;
237}
238
239impl dyn AgentSessionList {
240    pub fn downcast<T: 'static + AgentSessionList + Sized>(self: Rc<Self>) -> Option<Rc<T>> {
241        self.into_any().downcast().ok()
242    }
243}
244
245#[derive(Debug)]
246pub struct AuthRequired {
247    pub description: Option<String>,
248    pub provider_id: Option<LanguageModelProviderId>,
249}
250
251impl AuthRequired {
252    pub fn new() -> Self {
253        Self {
254            description: None,
255            provider_id: None,
256        }
257    }
258
259    pub fn with_description(mut self, description: String) -> Self {
260        self.description = Some(description);
261        self
262    }
263
264    pub fn with_language_model_provider(mut self, provider_id: LanguageModelProviderId) -> Self {
265        self.provider_id = Some(provider_id);
266        self
267    }
268}
269
270impl Error for AuthRequired {}
271impl fmt::Display for AuthRequired {
272    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
273        write!(f, "Authentication required")
274    }
275}
276
277/// Trait for agents that support listing, selecting, and querying language models.
278///
279/// This is an optional capability; agents indicate support via [AgentConnection::model_selector].
280pub trait AgentModelSelector: 'static {
281    /// Lists all available language models for this agent.
282    ///
283    /// # Parameters
284    /// - `cx`: The GPUI app context for async operations and global access.
285    ///
286    /// # Returns
287    /// A task resolving to the list of models or an error (e.g., if no models are configured).
288    fn list_models(&self, cx: &mut App) -> Task<Result<AgentModelList>>;
289
290    /// Selects a model for a specific session (thread).
291    ///
292    /// This sets the default model for future interactions in the session.
293    /// If the session doesn't exist or the model is invalid, it returns an error.
294    ///
295    /// # Parameters
296    /// - `model`: The model to select (should be one from [list_models]).
297    /// - `cx`: The GPUI app context.
298    ///
299    /// # Returns
300    /// A task resolving to `Ok(())` on success or an error.
301    fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>>;
302
303    /// Retrieves the currently selected model for a specific session (thread).
304    ///
305    /// # Parameters
306    /// - `cx`: The GPUI app context.
307    ///
308    /// # Returns
309    /// A task resolving to the selected model (always set) or an error (e.g., session not found).
310    fn selected_model(&self, cx: &mut App) -> Task<Result<AgentModelInfo>>;
311
312    /// Whenever the model list is updated the receiver will be notified.
313    /// Optional for agents that don't update their model list.
314    fn watch(&self, _cx: &mut App) -> Option<watch::Receiver<()>> {
315        None
316    }
317
318    /// Returns whether the model picker should render a footer.
319    fn should_render_footer(&self) -> bool {
320        false
321    }
322}
323
324/// Icon for a model in the model selector.
325#[derive(Debug, Clone, PartialEq, Eq)]
326pub enum AgentModelIcon {
327    /// A built-in icon from Zed's icon set.
328    Named(IconName),
329    /// Path to a custom SVG icon file.
330    Path(SharedString),
331}
332
333#[derive(Debug, Clone, PartialEq, Eq)]
334pub struct AgentModelInfo {
335    pub id: acp::ModelId,
336    pub name: SharedString,
337    pub description: Option<SharedString>,
338    pub icon: Option<AgentModelIcon>,
339}
340
341impl From<acp::ModelInfo> for AgentModelInfo {
342    fn from(info: acp::ModelInfo) -> Self {
343        Self {
344            id: info.model_id,
345            name: info.name.into(),
346            description: info.description.map(|desc| desc.into()),
347            icon: None,
348        }
349    }
350}
351
352#[derive(Debug, Clone, PartialEq, Eq, Hash)]
353pub struct AgentModelGroupName(pub SharedString);
354
355#[derive(Debug, Clone)]
356pub enum AgentModelList {
357    Flat(Vec<AgentModelInfo>),
358    Grouped(IndexMap<AgentModelGroupName, Vec<AgentModelInfo>>),
359}
360
361impl AgentModelList {
362    pub fn is_empty(&self) -> bool {
363        match self {
364            AgentModelList::Flat(models) => models.is_empty(),
365            AgentModelList::Grouped(groups) => groups.is_empty(),
366        }
367    }
368
369    pub fn is_flat(&self) -> bool {
370        matches!(self, AgentModelList::Flat(_))
371    }
372}
373
374#[cfg(feature = "test-support")]
375mod test_support {
376    //! Test-only stubs and helpers for acp_thread.
377    //!
378    //! This module is gated by the `test-support` feature and is not included
379    //! in production builds. It provides:
380    //! - `StubAgentConnection` for mocking agent connections in tests
381    //! - `create_test_png_base64` for generating test images
382
383    use std::sync::Arc;
384
385    use action_log::ActionLog;
386    use collections::HashMap;
387    use futures::{channel::oneshot, future::try_join_all};
388    use gpui::{AppContext as _, WeakEntity};
389    use parking_lot::Mutex;
390
391    use super::*;
392
393    /// Creates a PNG image encoded as base64 for testing.
394    ///
395    /// Generates a solid-color PNG of the specified dimensions and returns
396    /// it as a base64-encoded string suitable for use in `ImageContent`.
397    pub fn create_test_png_base64(width: u32, height: u32, color: [u8; 4]) -> String {
398        use image::ImageEncoder as _;
399
400        let mut png_data = Vec::new();
401        {
402            let encoder = image::codecs::png::PngEncoder::new(&mut png_data);
403            let mut pixels = Vec::with_capacity((width * height * 4) as usize);
404            for _ in 0..(width * height) {
405                pixels.extend_from_slice(&color);
406            }
407            encoder
408                .write_image(&pixels, width, height, image::ExtendedColorType::Rgba8)
409                .expect("Failed to encode PNG");
410        }
411
412        use image::EncodableLayout as _;
413        base64::Engine::encode(
414            &base64::engine::general_purpose::STANDARD,
415            png_data.as_bytes(),
416        )
417    }
418
419    #[derive(Clone, Default)]
420    pub struct StubAgentConnection {
421        sessions: Arc<Mutex<HashMap<acp::SessionId, Session>>>,
422        permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
423        next_prompt_updates: Arc<Mutex<Vec<acp::SessionUpdate>>>,
424    }
425
426    struct Session {
427        thread: WeakEntity<AcpThread>,
428        response_tx: Option<oneshot::Sender<acp::StopReason>>,
429    }
430
431    impl StubAgentConnection {
432        pub fn new() -> Self {
433            Self {
434                next_prompt_updates: Default::default(),
435                permission_requests: HashMap::default(),
436                sessions: Arc::default(),
437            }
438        }
439
440        pub fn set_next_prompt_updates(&self, updates: Vec<acp::SessionUpdate>) {
441            *self.next_prompt_updates.lock() = updates;
442        }
443
444        pub fn with_permission_requests(
445            mut self,
446            permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
447        ) -> Self {
448            self.permission_requests = permission_requests;
449            self
450        }
451
452        pub fn send_update(
453            &self,
454            session_id: acp::SessionId,
455            update: acp::SessionUpdate,
456            cx: &mut App,
457        ) {
458            assert!(
459                self.next_prompt_updates.lock().is_empty(),
460                "Use either send_update or set_next_prompt_updates"
461            );
462
463            self.sessions
464                .lock()
465                .get(&session_id)
466                .unwrap()
467                .thread
468                .update(cx, |thread, cx| {
469                    thread.handle_session_update(update, cx).unwrap();
470                })
471                .unwrap();
472        }
473
474        pub fn end_turn(&self, session_id: acp::SessionId, stop_reason: acp::StopReason) {
475            self.sessions
476                .lock()
477                .get_mut(&session_id)
478                .unwrap()
479                .response_tx
480                .take()
481                .expect("No pending turn")
482                .send(stop_reason)
483                .unwrap();
484        }
485    }
486
487    impl AgentConnection for StubAgentConnection {
488        fn telemetry_id(&self) -> SharedString {
489            "stub".into()
490        }
491
492        fn auth_methods(&self) -> &[acp::AuthMethod] {
493            &[]
494        }
495
496        fn model_selector(
497            &self,
498            _session_id: &acp::SessionId,
499        ) -> Option<Rc<dyn AgentModelSelector>> {
500            Some(self.model_selector_impl())
501        }
502
503        fn new_thread(
504            self: Rc<Self>,
505            project: Entity<Project>,
506            _cwd: &Path,
507            cx: &mut gpui::App,
508        ) -> Task<gpui::Result<Entity<AcpThread>>> {
509            let session_id = acp::SessionId::new(self.sessions.lock().len().to_string());
510            let action_log = cx.new(|_| ActionLog::new(project.clone()));
511            let thread = cx.new(|cx| {
512                AcpThread::new(
513                    "Test",
514                    self.clone(),
515                    project,
516                    action_log,
517                    session_id.clone(),
518                    watch::Receiver::constant(
519                        acp::PromptCapabilities::new()
520                            .image(true)
521                            .audio(true)
522                            .embedded_context(true),
523                    ),
524                    cx,
525                )
526            });
527            self.sessions.lock().insert(
528                session_id,
529                Session {
530                    thread: thread.downgrade(),
531                    response_tx: None,
532                },
533            );
534            Task::ready(Ok(thread))
535        }
536
537        fn authenticate(
538            &self,
539            _method_id: acp::AuthMethodId,
540            _cx: &mut App,
541        ) -> Task<gpui::Result<()>> {
542            unimplemented!()
543        }
544
545        fn prompt(
546            &self,
547            _id: Option<UserMessageId>,
548            params: acp::PromptRequest,
549            cx: &mut App,
550        ) -> Task<gpui::Result<acp::PromptResponse>> {
551            let mut sessions = self.sessions.lock();
552            let Session {
553                thread,
554                response_tx,
555            } = sessions.get_mut(&params.session_id).unwrap();
556            let mut tasks = vec![];
557            if self.next_prompt_updates.lock().is_empty() {
558                let (tx, rx) = oneshot::channel();
559                response_tx.replace(tx);
560                cx.spawn(async move |_| {
561                    let stop_reason = rx.await?;
562                    Ok(acp::PromptResponse::new(stop_reason))
563                })
564            } else {
565                for update in self.next_prompt_updates.lock().drain(..) {
566                    let thread = thread.clone();
567                    let update = update.clone();
568                    let permission_request = if let acp::SessionUpdate::ToolCall(tool_call) =
569                        &update
570                        && let Some(options) = self.permission_requests.get(&tool_call.tool_call_id)
571                    {
572                        Some((tool_call.clone(), options.clone()))
573                    } else {
574                        None
575                    };
576                    let task = cx.spawn(async move |cx| {
577                        if let Some((tool_call, options)) = permission_request {
578                            thread
579                                .update(cx, |thread, cx| {
580                                    thread.request_tool_call_authorization(
581                                        tool_call.clone().into(),
582                                        options.clone(),
583                                        false,
584                                        cx,
585                                    )
586                                })??
587                                .await;
588                        }
589                        thread.update(cx, |thread, cx| {
590                            thread.handle_session_update(update.clone(), cx).unwrap();
591                        })?;
592                        anyhow::Ok(())
593                    });
594                    tasks.push(task);
595                }
596
597                cx.spawn(async move |_| {
598                    try_join_all(tasks).await?;
599                    Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
600                })
601            }
602        }
603
604        fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
605            if let Some(end_turn_tx) = self
606                .sessions
607                .lock()
608                .get_mut(session_id)
609                .unwrap()
610                .response_tx
611                .take()
612            {
613                end_turn_tx.send(acp::StopReason::Cancelled).unwrap();
614            }
615        }
616
617        fn truncate(
618            &self,
619            _session_id: &agent_client_protocol::SessionId,
620            _cx: &App,
621        ) -> Option<Rc<dyn AgentSessionTruncate>> {
622            Some(Rc::new(StubAgentSessionEditor))
623        }
624
625        fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
626            self
627        }
628    }
629
630    struct StubAgentSessionEditor;
631
632    impl AgentSessionTruncate for StubAgentSessionEditor {
633        fn run(&self, _: UserMessageId, _: &mut App) -> Task<Result<()>> {
634            Task::ready(Ok(()))
635        }
636    }
637
638    #[derive(Clone)]
639    struct StubModelSelector {
640        selected_model: Arc<Mutex<AgentModelInfo>>,
641    }
642
643    impl StubModelSelector {
644        fn new() -> Self {
645            Self {
646                selected_model: Arc::new(Mutex::new(AgentModelInfo {
647                    id: acp::ModelId::new("visual-test-model"),
648                    name: "Visual Test Model".into(),
649                    description: Some("A stub model for visual testing".into()),
650                    icon: Some(AgentModelIcon::Named(ui::IconName::ZedAssistant)),
651                })),
652            }
653        }
654    }
655
656    impl AgentModelSelector for StubModelSelector {
657        fn list_models(&self, _cx: &mut App) -> Task<Result<AgentModelList>> {
658            let model = self.selected_model.lock().clone();
659            Task::ready(Ok(AgentModelList::Flat(vec![model])))
660        }
661
662        fn select_model(&self, model_id: acp::ModelId, _cx: &mut App) -> Task<Result<()>> {
663            self.selected_model.lock().id = model_id;
664            Task::ready(Ok(()))
665        }
666
667        fn selected_model(&self, _cx: &mut App) -> Task<Result<AgentModelInfo>> {
668            Task::ready(Ok(self.selected_model.lock().clone()))
669        }
670    }
671
672    impl StubAgentConnection {
673        /// Returns a model selector for this stub connection.
674        pub fn model_selector_impl(&self) -> Rc<dyn AgentModelSelector> {
675            Rc::new(StubModelSelector::new())
676        }
677    }
678}
679
680#[cfg(feature = "test-support")]
681pub use test_support::*;