connection.rs

  1use crate::AcpThread;
  2use agent_client_protocol::{self as acp};
  3use anyhow::Result;
  4use chrono::{DateTime, Utc};
  5use collections::IndexMap;
  6use gpui::{Entity, SharedString, Task};
  7use language_model::LanguageModelProviderId;
  8use project::Project;
  9use serde::{Deserialize, Serialize};
 10use std::{
 11    any::Any,
 12    error::Error,
 13    fmt,
 14    path::{Path, PathBuf},
 15    rc::Rc,
 16    sync::Arc,
 17};
 18use ui::{App, IconName};
 19use uuid::Uuid;
 20
 21#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
 22pub struct UserMessageId(Arc<str>);
 23
 24impl UserMessageId {
 25    pub fn new() -> Self {
 26        Self(Uuid::new_v4().to_string().into())
 27    }
 28}
 29
 30pub trait AgentConnection {
 31    fn telemetry_id(&self) -> SharedString;
 32
 33    fn new_thread(
 34        self: Rc<Self>,
 35        project: Entity<Project>,
 36        cwd: &Path,
 37        cx: &mut App,
 38    ) -> Task<Result<Entity<AcpThread>>>;
 39
 40    /// Whether this agent supports loading existing sessions.
 41    fn supports_load_session(&self, _cx: &App) -> bool {
 42        false
 43    }
 44
 45    /// Load an existing session by ID.
 46    fn load_session(
 47        self: Rc<Self>,
 48        _session: AgentSessionInfo,
 49        _project: Entity<Project>,
 50        _cwd: &Path,
 51        _cx: &mut App,
 52    ) -> Task<Result<Entity<AcpThread>>> {
 53        Task::ready(Err(anyhow::Error::msg("Loading sessions is not supported")))
 54    }
 55
 56    fn auth_methods(&self) -> &[acp::AuthMethod];
 57
 58    fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;
 59
 60    fn prompt(
 61        &self,
 62        user_message_id: Option<UserMessageId>,
 63        params: acp::PromptRequest,
 64        cx: &mut App,
 65    ) -> Task<Result<acp::PromptResponse>>;
 66
 67    fn resume(
 68        &self,
 69        _session_id: &acp::SessionId,
 70        _cx: &App,
 71    ) -> Option<Rc<dyn AgentSessionResume>> {
 72        None
 73    }
 74
 75    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
 76
 77    fn truncate(
 78        &self,
 79        _session_id: &acp::SessionId,
 80        _cx: &App,
 81    ) -> Option<Rc<dyn AgentSessionTruncate>> {
 82        None
 83    }
 84
 85    fn set_title(
 86        &self,
 87        _session_id: &acp::SessionId,
 88        _cx: &App,
 89    ) -> Option<Rc<dyn AgentSessionSetTitle>> {
 90        None
 91    }
 92
 93    /// Returns this agent as an [Rc<dyn ModelSelector>] if the model selection capability is supported.
 94    ///
 95    /// If the agent does not support model selection, returns [None].
 96    /// This allows sharing the selector in UI components.
 97    fn model_selector(&self, _session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
 98        None
 99    }
100
101    fn telemetry(&self) -> Option<Rc<dyn AgentTelemetry>> {
102        None
103    }
104
105    fn session_modes(
106        &self,
107        _session_id: &acp::SessionId,
108        _cx: &App,
109    ) -> Option<Rc<dyn AgentSessionModes>> {
110        None
111    }
112
113    fn session_config_options(
114        &self,
115        _session_id: &acp::SessionId,
116        _cx: &App,
117    ) -> Option<Rc<dyn AgentSessionConfigOptions>> {
118        None
119    }
120
121    fn session_list(&self, _cx: &mut App) -> Option<Rc<dyn AgentSessionList>> {
122        None
123    }
124
125    fn into_any(self: Rc<Self>) -> Rc<dyn Any>;
126}
127
128impl dyn AgentConnection {
129    pub fn downcast<T: 'static + AgentConnection + Sized>(self: Rc<Self>) -> Option<Rc<T>> {
130        self.into_any().downcast().ok()
131    }
132}
133
134pub trait AgentSessionTruncate {
135    fn run(&self, message_id: UserMessageId, cx: &mut App) -> Task<Result<()>>;
136}
137
138pub trait AgentSessionResume {
139    fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>>;
140}
141
142pub trait AgentSessionSetTitle {
143    fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>>;
144}
145
146pub trait AgentTelemetry {
147    /// A representation of the current thread state that can be serialized for
148    /// storage with telemetry events.
149    fn thread_data(
150        &self,
151        session_id: &acp::SessionId,
152        cx: &mut App,
153    ) -> Task<Result<serde_json::Value>>;
154}
155
156pub trait AgentSessionModes {
157    fn current_mode(&self) -> acp::SessionModeId;
158
159    fn all_modes(&self) -> Vec<acp::SessionMode>;
160
161    fn set_mode(&self, mode: acp::SessionModeId, cx: &mut App) -> Task<Result<()>>;
162}
163
164pub trait AgentSessionConfigOptions {
165    /// Get all current config options with their state
166    fn config_options(&self) -> Vec<acp::SessionConfigOption>;
167
168    /// Set a config option value
169    /// Returns the full updated list of config options
170    fn set_config_option(
171        &self,
172        config_id: acp::SessionConfigId,
173        value: acp::SessionConfigValueId,
174        cx: &mut App,
175    ) -> Task<Result<Vec<acp::SessionConfigOption>>>;
176
177    /// Whenever the config options are updated the receiver will be notified.
178    /// Optional for agents that don't update their config options dynamically.
179    fn watch(&self, _cx: &mut App) -> Option<watch::Receiver<()>> {
180        None
181    }
182}
183
184#[derive(Debug, Clone, Default)]
185pub struct AgentSessionListRequest {
186    pub cwd: Option<PathBuf>,
187    pub cursor: Option<String>,
188    pub meta: Option<acp::Meta>,
189}
190
191#[derive(Debug, Clone)]
192pub struct AgentSessionListResponse {
193    pub sessions: Vec<AgentSessionInfo>,
194    pub next_cursor: Option<String>,
195    pub meta: Option<acp::Meta>,
196}
197
198impl AgentSessionListResponse {
199    pub fn new(sessions: Vec<AgentSessionInfo>) -> Self {
200        Self {
201            sessions,
202            next_cursor: None,
203            meta: None,
204        }
205    }
206}
207
208#[derive(Debug, Clone, PartialEq)]
209pub struct AgentSessionInfo {
210    pub session_id: acp::SessionId,
211    pub cwd: Option<PathBuf>,
212    pub title: Option<SharedString>,
213    pub updated_at: Option<DateTime<Utc>>,
214    pub meta: Option<acp::Meta>,
215}
216
217impl AgentSessionInfo {
218    pub fn new(session_id: impl Into<acp::SessionId>) -> Self {
219        Self {
220            session_id: session_id.into(),
221            cwd: None,
222            title: None,
223            updated_at: None,
224            meta: None,
225        }
226    }
227}
228
229#[derive(Debug, Clone)]
230pub enum SessionListUpdate {
231    Refresh,
232    SessionInfo {
233        session_id: acp::SessionId,
234        update: acp::SessionInfoUpdate,
235    },
236}
237
238pub trait AgentSessionList {
239    fn list_sessions(
240        &self,
241        request: AgentSessionListRequest,
242        cx: &mut App,
243    ) -> Task<Result<AgentSessionListResponse>>;
244
245    fn supports_delete(&self) -> bool {
246        false
247    }
248
249    fn delete_session(&self, _session_id: &acp::SessionId, _cx: &mut App) -> Task<Result<()>> {
250        Task::ready(Err(anyhow::anyhow!("delete_session not supported")))
251    }
252
253    fn delete_sessions(&self, _cx: &mut App) -> Task<Result<()>> {
254        Task::ready(Err(anyhow::anyhow!("delete_sessions not supported")))
255    }
256
257    fn watch(&self, _cx: &mut App) -> Option<smol::channel::Receiver<SessionListUpdate>> {
258        None
259    }
260
261    fn into_any(self: Rc<Self>) -> Rc<dyn Any>;
262}
263
264impl dyn AgentSessionList {
265    pub fn downcast<T: 'static + AgentSessionList + Sized>(self: Rc<Self>) -> Option<Rc<T>> {
266        self.into_any().downcast().ok()
267    }
268}
269
270#[derive(Debug)]
271pub struct AuthRequired {
272    pub description: Option<String>,
273    pub provider_id: Option<LanguageModelProviderId>,
274}
275
276impl AuthRequired {
277    pub fn new() -> Self {
278        Self {
279            description: None,
280            provider_id: None,
281        }
282    }
283
284    pub fn with_description(mut self, description: String) -> Self {
285        self.description = Some(description);
286        self
287    }
288
289    pub fn with_language_model_provider(mut self, provider_id: LanguageModelProviderId) -> Self {
290        self.provider_id = Some(provider_id);
291        self
292    }
293}
294
295impl Error for AuthRequired {}
296impl fmt::Display for AuthRequired {
297    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
298        write!(f, "Authentication required")
299    }
300}
301
302/// Trait for agents that support listing, selecting, and querying language models.
303///
304/// This is an optional capability; agents indicate support via [AgentConnection::model_selector].
305pub trait AgentModelSelector: 'static {
306    /// Lists all available language models for this agent.
307    ///
308    /// # Parameters
309    /// - `cx`: The GPUI app context for async operations and global access.
310    ///
311    /// # Returns
312    /// A task resolving to the list of models or an error (e.g., if no models are configured).
313    fn list_models(&self, cx: &mut App) -> Task<Result<AgentModelList>>;
314
315    /// Selects a model for a specific session (thread).
316    ///
317    /// This sets the default model for future interactions in the session.
318    /// If the session doesn't exist or the model is invalid, it returns an error.
319    ///
320    /// # Parameters
321    /// - `model`: The model to select (should be one from [list_models]).
322    /// - `cx`: The GPUI app context.
323    ///
324    /// # Returns
325    /// A task resolving to `Ok(())` on success or an error.
326    fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>>;
327
328    /// Retrieves the currently selected model for a specific session (thread).
329    ///
330    /// # Parameters
331    /// - `cx`: The GPUI app context.
332    ///
333    /// # Returns
334    /// A task resolving to the selected model (always set) or an error (e.g., session not found).
335    fn selected_model(&self, cx: &mut App) -> Task<Result<AgentModelInfo>>;
336
337    /// Whenever the model list is updated the receiver will be notified.
338    /// Optional for agents that don't update their model list.
339    fn watch(&self, _cx: &mut App) -> Option<watch::Receiver<()>> {
340        None
341    }
342
343    /// Returns whether the model picker should render a footer.
344    fn should_render_footer(&self) -> bool {
345        false
346    }
347}
348
349/// Icon for a model in the model selector.
350#[derive(Debug, Clone, PartialEq, Eq)]
351pub enum AgentModelIcon {
352    /// A built-in icon from Zed's icon set.
353    Named(IconName),
354    /// Path to a custom SVG icon file.
355    Path(SharedString),
356}
357
358#[derive(Debug, Clone, PartialEq, Eq)]
359pub struct AgentModelInfo {
360    pub id: acp::ModelId,
361    pub name: SharedString,
362    pub description: Option<SharedString>,
363    pub icon: Option<AgentModelIcon>,
364}
365
366impl From<acp::ModelInfo> for AgentModelInfo {
367    fn from(info: acp::ModelInfo) -> Self {
368        Self {
369            id: info.model_id,
370            name: info.name.into(),
371            description: info.description.map(|desc| desc.into()),
372            icon: None,
373        }
374    }
375}
376
377#[derive(Debug, Clone, PartialEq, Eq, Hash)]
378pub struct AgentModelGroupName(pub SharedString);
379
380#[derive(Debug, Clone)]
381pub enum AgentModelList {
382    Flat(Vec<AgentModelInfo>),
383    Grouped(IndexMap<AgentModelGroupName, Vec<AgentModelInfo>>),
384}
385
386impl AgentModelList {
387    pub fn is_empty(&self) -> bool {
388        match self {
389            AgentModelList::Flat(models) => models.is_empty(),
390            AgentModelList::Grouped(groups) => groups.is_empty(),
391        }
392    }
393
394    pub fn is_flat(&self) -> bool {
395        matches!(self, AgentModelList::Flat(_))
396    }
397}
398
399#[derive(Debug, Clone)]
400pub struct PermissionOptionChoice {
401    pub allow: acp::PermissionOption,
402    pub deny: acp::PermissionOption,
403}
404
405impl PermissionOptionChoice {
406    pub fn label(&self) -> SharedString {
407        self.allow.name.clone().into()
408    }
409}
410
411#[derive(Debug, Clone)]
412pub enum PermissionOptions {
413    Flat(Vec<acp::PermissionOption>),
414    Dropdown(Vec<PermissionOptionChoice>),
415}
416
417impl PermissionOptions {
418    pub fn is_empty(&self) -> bool {
419        match self {
420            PermissionOptions::Flat(options) => options.is_empty(),
421            PermissionOptions::Dropdown(options) => options.is_empty(),
422        }
423    }
424
425    pub fn first_option_of_kind(
426        &self,
427        kind: acp::PermissionOptionKind,
428    ) -> Option<&acp::PermissionOption> {
429        match self {
430            PermissionOptions::Flat(options) => options.iter().find(|option| option.kind == kind),
431            PermissionOptions::Dropdown(options) => options.iter().find_map(|choice| {
432                if choice.allow.kind == kind {
433                    Some(&choice.allow)
434                } else if choice.deny.kind == kind {
435                    Some(&choice.deny)
436                } else {
437                    None
438                }
439            }),
440        }
441    }
442
443    pub fn allow_once_option_id(&self) -> Option<acp::PermissionOptionId> {
444        self.first_option_of_kind(acp::PermissionOptionKind::AllowOnce)
445            .map(|option| option.option_id.clone())
446    }
447}
448
#[cfg(feature = "test-support")]
mod test_support {
    //! Test-only stubs and helpers for acp_thread.
    //!
    //! This module is gated by the `test-support` feature and is not included
    //! in production builds. It provides:
    //! - `StubAgentConnection` for mocking agent connections in tests
    //! - `create_test_png_base64` for generating test images

    use std::sync::Arc;

    use action_log::ActionLog;
    use collections::HashMap;
    use futures::{channel::oneshot, future::try_join_all};
    use gpui::{AppContext as _, WeakEntity};
    use parking_lot::Mutex;

    use super::*;

    /// Creates a PNG image encoded as base64 for testing.
    ///
    /// Generates a solid-color RGBA PNG of the specified dimensions and returns
    /// it as a base64-encoded string suitable for use in `ImageContent`.
    ///
    /// # Panics
    /// Panics if the PNG encoder reports an error.
    pub fn create_test_png_base64(width: u32, height: u32, color: [u8; 4]) -> String {
        use image::ImageEncoder as _;

        // Compute the pixel count in `usize` so that `width * height * 4`
        // cannot overflow `u32` arithmetic for large images.
        let pixel_count = width as usize * height as usize;
        let pixels = color.repeat(pixel_count);

        let mut png_data = Vec::new();
        image::codecs::png::PngEncoder::new(&mut png_data)
            .write_image(&pixels, width, height, image::ExtendedColorType::Rgba8)
            .expect("Failed to encode PNG");

        base64::Engine::encode(&base64::engine::general_purpose::STANDARD, &png_data)
    }

    /// An in-memory [`AgentConnection`] for tests.
    ///
    /// A prompt either replays the updates queued with
    /// [`StubAgentConnection::set_next_prompt_updates`] and ends the turn, or,
    /// when nothing is queued, stays pending until
    /// [`StubAgentConnection::end_turn`] or `cancel` supplies a stop reason.
    #[derive(Clone, Default)]
    pub struct StubAgentConnection {
        sessions: Arc<Mutex<HashMap<acp::SessionId, Session>>>,
        permission_requests: HashMap<acp::ToolCallId, PermissionOptions>,
        next_prompt_updates: Arc<Mutex<Vec<acp::SessionUpdate>>>,
    }

    /// Per-session state tracked by [`StubAgentConnection`].
    struct Session {
        thread: WeakEntity<AcpThread>,
        /// Sender for the pending turn's stop reason. Set by a prompt that has
        /// no queued updates and taken by `end_turn` / `cancel`.
        response_tx: Option<oneshot::Sender<acp::StopReason>>,
    }

    impl StubAgentConnection {
        /// Creates a connection with no sessions, no permission requests, and
        /// no queued prompt updates.
        pub fn new() -> Self {
            Self::default()
        }

        /// Queues the updates that the next `prompt` call replays.
        pub fn set_next_prompt_updates(&self, updates: Vec<acp::SessionUpdate>) {
            *self.next_prompt_updates.lock() = updates;
        }

        /// Makes tool-call updates whose ids appear in `permission_requests`
        /// request authorization with the given options before being applied.
        pub fn with_permission_requests(
            mut self,
            permission_requests: HashMap<acp::ToolCallId, PermissionOptions>,
        ) -> Self {
            self.permission_requests = permission_requests;
            self
        }

        /// Immediately applies `update` to the thread of `session_id`.
        ///
        /// # Panics
        /// Panics if updates are queued via `set_next_prompt_updates`, if the
        /// session is unknown, if its thread has been released, or if the
        /// thread rejects the update.
        pub fn send_update(
            &self,
            session_id: acp::SessionId,
            update: acp::SessionUpdate,
            cx: &mut App,
        ) {
            assert!(
                self.next_prompt_updates.lock().is_empty(),
                "Use either send_update or set_next_prompt_updates"
            );

            // Release the (non-reentrant) sessions lock before touching the
            // thread, so the update cannot deadlock if it re-enters this connection.
            let thread = self.sessions.lock().get(&session_id).unwrap().thread.clone();
            thread
                .update(cx, |thread, cx| {
                    thread.handle_session_update(update, cx).unwrap();
                })
                .unwrap();
        }

        /// Completes the pending turn of `session_id` with `stop_reason`.
        ///
        /// # Panics
        /// Panics if the session is unknown or has no pending turn.
        pub fn end_turn(&self, session_id: acp::SessionId, stop_reason: acp::StopReason) {
            self.sessions
                .lock()
                .get_mut(&session_id)
                .unwrap()
                .response_tx
                .take()
                .expect("No pending turn")
                .send(stop_reason)
                .unwrap();
        }
    }

    impl AgentConnection for StubAgentConnection {
        fn telemetry_id(&self) -> SharedString {
            "stub".into()
        }

        fn auth_methods(&self) -> &[acp::AuthMethod] {
            &[]
        }

        fn model_selector(
            &self,
            _session_id: &acp::SessionId,
        ) -> Option<Rc<dyn AgentModelSelector>> {
            Some(self.model_selector_impl())
        }

        /// Creates a thread whose session id is the current session count.
        fn new_thread(
            self: Rc<Self>,
            project: Entity<Project>,
            _cwd: &Path,
            cx: &mut gpui::App,
        ) -> Task<gpui::Result<Entity<AcpThread>>> {
            let session_id = acp::SessionId::new(self.sessions.lock().len().to_string());
            let action_log = cx.new(|_| ActionLog::new(project.clone()));
            let thread = cx.new(|cx| {
                AcpThread::new(
                    "Test",
                    self.clone(),
                    project,
                    action_log,
                    session_id.clone(),
                    watch::Receiver::constant(
                        acp::PromptCapabilities::new()
                            .image(true)
                            .audio(true)
                            .embedded_context(true),
                    ),
                    cx,
                )
            });
            self.sessions.lock().insert(
                session_id,
                Session {
                    thread: thread.downgrade(),
                    response_tx: None,
                },
            );
            Task::ready(Ok(thread))
        }

        fn authenticate(
            &self,
            _method_id: acp::AuthMethodId,
            _cx: &mut App,
        ) -> Task<gpui::Result<()>> {
            unimplemented!()
        }

        fn prompt(
            &self,
            _id: Option<UserMessageId>,
            params: acp::PromptRequest,
            cx: &mut App,
        ) -> Task<gpui::Result<acp::PromptResponse>> {
            let mut sessions = self.sessions.lock();
            let Session {
                thread,
                response_tx,
            } = sessions.get_mut(&params.session_id).unwrap();

            // Check and drain the queue under a single lock acquisition.
            let updates = std::mem::take(&mut *self.next_prompt_updates.lock());

            if updates.is_empty() {
                // Nothing scripted: the turn stays pending until `end_turn` or
                // `cancel` sends a stop reason through `response_tx`.
                let (tx, rx) = oneshot::channel();
                response_tx.replace(tx);
                return cx.spawn(async move |_| {
                    let stop_reason = rx.await?;
                    Ok(acp::PromptResponse::new(stop_reason))
                });
            }

            let mut tasks = Vec::with_capacity(updates.len());
            for update in updates {
                let thread = thread.clone();
                // Tool calls registered via `with_permission_requests` are
                // authorized before their update is applied.
                let permission_request = match &update {
                    acp::SessionUpdate::ToolCall(tool_call) => self
                        .permission_requests
                        .get(&tool_call.tool_call_id)
                        .map(|options| (tool_call.clone(), options.clone())),
                    _ => None,
                };
                tasks.push(cx.spawn(async move |cx| {
                    if let Some((tool_call, options)) = permission_request {
                        thread
                            .update(cx, |thread, cx| {
                                thread.request_tool_call_authorization(
                                    tool_call.into(),
                                    options,
                                    false,
                                    cx,
                                )
                            })??
                            .await;
                    }
                    thread.update(cx, |thread, cx| {
                        thread.handle_session_update(update, cx).unwrap();
                    })?;
                    anyhow::Ok(())
                }));
            }

            cx.spawn(async move |_| {
                try_join_all(tasks).await?;
                Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
            })
        }

        /// Ends the pending turn (if any) with `StopReason::Cancelled`.
        fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
            if let Some(end_turn_tx) = self
                .sessions
                .lock()
                .get_mut(session_id)
                .unwrap()
                .response_tx
                .take()
            {
                end_turn_tx.send(acp::StopReason::Cancelled).unwrap();
            }
        }

        fn truncate(
            &self,
            _session_id: &agent_client_protocol::SessionId,
            _cx: &App,
        ) -> Option<Rc<dyn AgentSessionTruncate>> {
            Some(Rc::new(StubAgentSessionEditor))
        }

        fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
            self
        }
    }

    /// Truncation handle whose `run` always succeeds without doing anything.
    struct StubAgentSessionEditor;

    impl AgentSessionTruncate for StubAgentSessionEditor {
        fn run(&self, _: UserMessageId, _: &mut App) -> Task<Result<()>> {
            Task::ready(Ok(()))
        }
    }

    /// Model selector offering a single fixed model. `select_model` only
    /// overwrites the stored model's id.
    #[derive(Clone)]
    struct StubModelSelector {
        selected_model: Arc<Mutex<AgentModelInfo>>,
    }

    impl StubModelSelector {
        fn new() -> Self {
            Self {
                selected_model: Arc::new(Mutex::new(AgentModelInfo {
                    id: acp::ModelId::new("visual-test-model"),
                    name: "Visual Test Model".into(),
                    description: Some("A stub model for visual testing".into()),
                    icon: Some(AgentModelIcon::Named(ui::IconName::ZedAssistant)),
                })),
            }
        }
    }

    impl AgentModelSelector for StubModelSelector {
        fn list_models(&self, _cx: &mut App) -> Task<Result<AgentModelList>> {
            let model = self.selected_model.lock().clone();
            Task::ready(Ok(AgentModelList::Flat(vec![model])))
        }

        fn select_model(&self, model_id: acp::ModelId, _cx: &mut App) -> Task<Result<()>> {
            self.selected_model.lock().id = model_id;
            Task::ready(Ok(()))
        }

        fn selected_model(&self, _cx: &mut App) -> Task<Result<AgentModelInfo>> {
            Task::ready(Ok(self.selected_model.lock().clone()))
        }
    }

    impl StubAgentConnection {
        /// Returns a model selector for this stub connection.
        ///
        /// Each call creates a new selector, so selections are not shared
        /// between calls.
        pub fn model_selector_impl(&self) -> Rc<dyn AgentModelSelector> {
            Rc::new(StubModelSelector::new())
        }
    }
}
754
// Re-export the test stubs from this module when `test-support` is enabled.
#[cfg(feature = "test-support")]
pub use self::test_support::*;