// connection.rs

use crate::AcpThread;
use agent_client_protocol::{self as acp};
use anyhow::Result;
use chrono::{DateTime, Utc};
use collections::IndexMap;
use gpui::{Entity, SharedString, Task};
use language_model::LanguageModelProviderId;
use project::Project;
use serde::{Deserialize, Serialize};
use std::{
    any::Any,
    error::Error,
    fmt,
    path::{Path, PathBuf},
    rc::Rc,
    sync::Arc,
};
use ui::{App, IconName};
use uuid::Uuid;
 20
 21#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
 22pub struct UserMessageId(Arc<str>);
 23
 24impl UserMessageId {
 25    pub fn new() -> Self {
 26        Self(Uuid::new_v4().to_string().into())
 27    }
 28}
 29
 30pub trait AgentConnection {
 31    fn telemetry_id(&self) -> SharedString;
 32
 33    fn new_thread(
 34        self: Rc<Self>,
 35        project: Entity<Project>,
 36        cwd: &Path,
 37        cx: &mut App,
 38    ) -> Task<Result<Entity<AcpThread>>>;
 39
 40    /// Whether this agent supports loading existing sessions.
 41    fn supports_load_session(&self, _cx: &App) -> bool {
 42        false
 43    }
 44
 45    /// Load an existing session by ID.
 46    fn load_session(
 47        self: Rc<Self>,
 48        _session: AgentSessionInfo,
 49        _project: Entity<Project>,
 50        _cwd: &Path,
 51        _cx: &mut App,
 52    ) -> Task<Result<Entity<AcpThread>>> {
 53        Task::ready(Err(anyhow::Error::msg("Loading sessions is not supported")))
 54    }
 55
 56    fn auth_methods(&self) -> &[acp::AuthMethod];
 57
 58    fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;
 59
 60    fn prompt(
 61        &self,
 62        user_message_id: Option<UserMessageId>,
 63        params: acp::PromptRequest,
 64        cx: &mut App,
 65    ) -> Task<Result<acp::PromptResponse>>;
 66
 67    fn resume(
 68        &self,
 69        _session_id: &acp::SessionId,
 70        _cx: &App,
 71    ) -> Option<Rc<dyn AgentSessionResume>> {
 72        None
 73    }
 74
 75    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
 76
 77    fn truncate(
 78        &self,
 79        _session_id: &acp::SessionId,
 80        _cx: &App,
 81    ) -> Option<Rc<dyn AgentSessionTruncate>> {
 82        None
 83    }
 84
 85    fn set_title(
 86        &self,
 87        _session_id: &acp::SessionId,
 88        _cx: &App,
 89    ) -> Option<Rc<dyn AgentSessionSetTitle>> {
 90        None
 91    }
 92
 93    /// Returns this agent as an [Rc<dyn ModelSelector>] if the model selection capability is supported.
 94    ///
 95    /// If the agent does not support model selection, returns [None].
 96    /// This allows sharing the selector in UI components.
 97    fn model_selector(&self, _session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
 98        None
 99    }
100
101    fn telemetry(&self) -> Option<Rc<dyn AgentTelemetry>> {
102        None
103    }
104
105    fn session_modes(
106        &self,
107        _session_id: &acp::SessionId,
108        _cx: &App,
109    ) -> Option<Rc<dyn AgentSessionModes>> {
110        None
111    }
112
113    fn session_config_options(
114        &self,
115        _session_id: &acp::SessionId,
116        _cx: &App,
117    ) -> Option<Rc<dyn AgentSessionConfigOptions>> {
118        None
119    }
120
121    fn session_list(&self, _cx: &mut App) -> Option<Rc<dyn AgentSessionList>> {
122        None
123    }
124
125    fn into_any(self: Rc<Self>) -> Rc<dyn Any>;
126}
127
128impl dyn AgentConnection {
129    pub fn downcast<T: 'static + AgentConnection + Sized>(self: Rc<Self>) -> Option<Rc<T>> {
130        self.into_any().downcast().ok()
131    }
132}
133
134pub trait AgentSessionTruncate {
135    fn run(&self, message_id: UserMessageId, cx: &mut App) -> Task<Result<()>>;
136}
137
138pub trait AgentSessionResume {
139    fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>>;
140}
141
142pub trait AgentSessionSetTitle {
143    fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>>;
144}
145
146pub trait AgentTelemetry {
147    /// A representation of the current thread state that can be serialized for
148    /// storage with telemetry events.
149    fn thread_data(
150        &self,
151        session_id: &acp::SessionId,
152        cx: &mut App,
153    ) -> Task<Result<serde_json::Value>>;
154}
155
156pub trait AgentSessionModes {
157    fn current_mode(&self) -> acp::SessionModeId;
158
159    fn all_modes(&self) -> Vec<acp::SessionMode>;
160
161    fn set_mode(&self, mode: acp::SessionModeId, cx: &mut App) -> Task<Result<()>>;
162}
163
164pub trait AgentSessionConfigOptions {
165    /// Get all current config options with their state
166    fn config_options(&self) -> Vec<acp::SessionConfigOption>;
167
168    /// Set a config option value
169    /// Returns the full updated list of config options
170    fn set_config_option(
171        &self,
172        config_id: acp::SessionConfigId,
173        value: acp::SessionConfigValueId,
174        cx: &mut App,
175    ) -> Task<Result<Vec<acp::SessionConfigOption>>>;
176
177    /// Whenever the config options are updated the receiver will be notified.
178    /// Optional for agents that don't update their config options dynamically.
179    fn watch(&self, _cx: &mut App) -> Option<watch::Receiver<()>> {
180        None
181    }
182}
183
184#[derive(Debug, Clone, Default)]
185pub struct AgentSessionListRequest {
186    pub cwd: Option<PathBuf>,
187    pub cursor: Option<String>,
188    pub meta: Option<acp::Meta>,
189}
190
191#[derive(Debug, Clone)]
192pub struct AgentSessionListResponse {
193    pub sessions: Vec<AgentSessionInfo>,
194    pub next_cursor: Option<String>,
195    pub meta: Option<acp::Meta>,
196}
197
198impl AgentSessionListResponse {
199    pub fn new(sessions: Vec<AgentSessionInfo>) -> Self {
200        Self {
201            sessions,
202            next_cursor: None,
203            meta: None,
204        }
205    }
206}
207
208#[derive(Debug, Clone, PartialEq)]
209pub struct AgentSessionInfo {
210    pub session_id: acp::SessionId,
211    pub cwd: Option<PathBuf>,
212    pub title: Option<SharedString>,
213    pub updated_at: Option<DateTime<Utc>>,
214    pub meta: Option<acp::Meta>,
215}
216
217impl AgentSessionInfo {
218    pub fn new(session_id: impl Into<acp::SessionId>) -> Self {
219        Self {
220            session_id: session_id.into(),
221            cwd: None,
222            title: None,
223            updated_at: None,
224            meta: None,
225        }
226    }
227}
228
229pub trait AgentSessionList {
230    fn list_sessions(
231        &self,
232        request: AgentSessionListRequest,
233        cx: &mut App,
234    ) -> Task<Result<AgentSessionListResponse>>;
235
236    fn supports_delete(&self) -> bool {
237        false
238    }
239
240    fn delete_session(&self, _session_id: &acp::SessionId, _cx: &mut App) -> Task<Result<()>> {
241        Task::ready(Err(anyhow::anyhow!("delete_session not supported")))
242    }
243
244    fn delete_sessions(&self, _cx: &mut App) -> Task<Result<()>> {
245        Task::ready(Err(anyhow::anyhow!("delete_sessions not supported")))
246    }
247
248    fn watch(&self, _cx: &mut App) -> Option<watch::Receiver<()>> {
249        None
250    }
251
252    fn into_any(self: Rc<Self>) -> Rc<dyn Any>;
253}
254
255impl dyn AgentSessionList {
256    pub fn downcast<T: 'static + AgentSessionList + Sized>(self: Rc<Self>) -> Option<Rc<T>> {
257        self.into_any().downcast().ok()
258    }
259}
260
261#[derive(Debug)]
262pub struct AuthRequired {
263    pub description: Option<String>,
264    pub provider_id: Option<LanguageModelProviderId>,
265}
266
267impl AuthRequired {
268    pub fn new() -> Self {
269        Self {
270            description: None,
271            provider_id: None,
272        }
273    }
274
275    pub fn with_description(mut self, description: String) -> Self {
276        self.description = Some(description);
277        self
278    }
279
280    pub fn with_language_model_provider(mut self, provider_id: LanguageModelProviderId) -> Self {
281        self.provider_id = Some(provider_id);
282        self
283    }
284}
285
286impl Error for AuthRequired {}
287impl fmt::Display for AuthRequired {
288    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
289        write!(f, "Authentication required")
290    }
291}
292
293/// Trait for agents that support listing, selecting, and querying language models.
294///
295/// This is an optional capability; agents indicate support via [AgentConnection::model_selector].
296pub trait AgentModelSelector: 'static {
297    /// Lists all available language models for this agent.
298    ///
299    /// # Parameters
300    /// - `cx`: The GPUI app context for async operations and global access.
301    ///
302    /// # Returns
303    /// A task resolving to the list of models or an error (e.g., if no models are configured).
304    fn list_models(&self, cx: &mut App) -> Task<Result<AgentModelList>>;
305
306    /// Selects a model for a specific session (thread).
307    ///
308    /// This sets the default model for future interactions in the session.
309    /// If the session doesn't exist or the model is invalid, it returns an error.
310    ///
311    /// # Parameters
312    /// - `model`: The model to select (should be one from [list_models]).
313    /// - `cx`: The GPUI app context.
314    ///
315    /// # Returns
316    /// A task resolving to `Ok(())` on success or an error.
317    fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>>;
318
319    /// Retrieves the currently selected model for a specific session (thread).
320    ///
321    /// # Parameters
322    /// - `cx`: The GPUI app context.
323    ///
324    /// # Returns
325    /// A task resolving to the selected model (always set) or an error (e.g., session not found).
326    fn selected_model(&self, cx: &mut App) -> Task<Result<AgentModelInfo>>;
327
328    /// Whenever the model list is updated the receiver will be notified.
329    /// Optional for agents that don't update their model list.
330    fn watch(&self, _cx: &mut App) -> Option<watch::Receiver<()>> {
331        None
332    }
333
334    /// Returns whether the model picker should render a footer.
335    fn should_render_footer(&self) -> bool {
336        false
337    }
338}
339
340/// Icon for a model in the model selector.
341#[derive(Debug, Clone, PartialEq, Eq)]
342pub enum AgentModelIcon {
343    /// A built-in icon from Zed's icon set.
344    Named(IconName),
345    /// Path to a custom SVG icon file.
346    Path(SharedString),
347}
348
349#[derive(Debug, Clone, PartialEq, Eq)]
350pub struct AgentModelInfo {
351    pub id: acp::ModelId,
352    pub name: SharedString,
353    pub description: Option<SharedString>,
354    pub icon: Option<AgentModelIcon>,
355}
356
357impl From<acp::ModelInfo> for AgentModelInfo {
358    fn from(info: acp::ModelInfo) -> Self {
359        Self {
360            id: info.model_id,
361            name: info.name.into(),
362            description: info.description.map(|desc| desc.into()),
363            icon: None,
364        }
365    }
366}
367
368#[derive(Debug, Clone, PartialEq, Eq, Hash)]
369pub struct AgentModelGroupName(pub SharedString);
370
371#[derive(Debug, Clone)]
372pub enum AgentModelList {
373    Flat(Vec<AgentModelInfo>),
374    Grouped(IndexMap<AgentModelGroupName, Vec<AgentModelInfo>>),
375}
376
377impl AgentModelList {
378    pub fn is_empty(&self) -> bool {
379        match self {
380            AgentModelList::Flat(models) => models.is_empty(),
381            AgentModelList::Grouped(groups) => groups.is_empty(),
382        }
383    }
384
385    pub fn is_flat(&self) -> bool {
386        matches!(self, AgentModelList::Flat(_))
387    }
388}
389
390#[derive(Debug, Clone)]
391pub struct PermissionOptionChoice {
392    pub allow: acp::PermissionOption,
393    pub deny: acp::PermissionOption,
394}
395
396impl PermissionOptionChoice {
397    pub fn label(&self) -> SharedString {
398        self.allow.name.clone().into()
399    }
400}
401
402#[derive(Debug, Clone)]
403pub enum PermissionOptions {
404    Flat(Vec<acp::PermissionOption>),
405    Dropdown(Vec<PermissionOptionChoice>),
406}
407
408impl PermissionOptions {
409    pub fn is_empty(&self) -> bool {
410        match self {
411            PermissionOptions::Flat(options) => options.is_empty(),
412            PermissionOptions::Dropdown(options) => options.is_empty(),
413        }
414    }
415
416    pub fn first_option_of_kind(
417        &self,
418        kind: acp::PermissionOptionKind,
419    ) -> Option<&acp::PermissionOption> {
420        match self {
421            PermissionOptions::Flat(options) => options.iter().find(|option| option.kind == kind),
422            PermissionOptions::Dropdown(options) => options.iter().find_map(|choice| {
423                if choice.allow.kind == kind {
424                    Some(&choice.allow)
425                } else if choice.deny.kind == kind {
426                    Some(&choice.deny)
427                } else {
428                    None
429                }
430            }),
431        }
432    }
433
434    pub fn allow_once_option_id(&self) -> Option<acp::PermissionOptionId> {
435        self.first_option_of_kind(acp::PermissionOptionKind::AllowOnce)
436            .map(|option| option.option_id.clone())
437    }
438}
439
440#[cfg(feature = "test-support")]
441mod test_support {
442    //! Test-only stubs and helpers for acp_thread.
443    //!
444    //! This module is gated by the `test-support` feature and is not included
445    //! in production builds. It provides:
446    //! - `StubAgentConnection` for mocking agent connections in tests
447    //! - `create_test_png_base64` for generating test images
448
449    use std::sync::Arc;
450
451    use action_log::ActionLog;
452    use collections::HashMap;
453    use futures::{channel::oneshot, future::try_join_all};
454    use gpui::{AppContext as _, WeakEntity};
455    use parking_lot::Mutex;
456
457    use super::*;
458
459    /// Creates a PNG image encoded as base64 for testing.
460    ///
461    /// Generates a solid-color PNG of the specified dimensions and returns
462    /// it as a base64-encoded string suitable for use in `ImageContent`.
463    pub fn create_test_png_base64(width: u32, height: u32, color: [u8; 4]) -> String {
464        use image::ImageEncoder as _;
465
466        let mut png_data = Vec::new();
467        {
468            let encoder = image::codecs::png::PngEncoder::new(&mut png_data);
469            let mut pixels = Vec::with_capacity((width * height * 4) as usize);
470            for _ in 0..(width * height) {
471                pixels.extend_from_slice(&color);
472            }
473            encoder
474                .write_image(&pixels, width, height, image::ExtendedColorType::Rgba8)
475                .expect("Failed to encode PNG");
476        }
477
478        use image::EncodableLayout as _;
479        base64::Engine::encode(
480            &base64::engine::general_purpose::STANDARD,
481            png_data.as_bytes(),
482        )
483    }
484
485    #[derive(Clone, Default)]
486    pub struct StubAgentConnection {
487        sessions: Arc<Mutex<HashMap<acp::SessionId, Session>>>,
488        permission_requests: HashMap<acp::ToolCallId, PermissionOptions>,
489        next_prompt_updates: Arc<Mutex<Vec<acp::SessionUpdate>>>,
490    }
491
492    struct Session {
493        thread: WeakEntity<AcpThread>,
494        response_tx: Option<oneshot::Sender<acp::StopReason>>,
495    }
496
497    impl StubAgentConnection {
498        pub fn new() -> Self {
499            Self {
500                next_prompt_updates: Default::default(),
501                permission_requests: HashMap::default(),
502                sessions: Arc::default(),
503            }
504        }
505
506        pub fn set_next_prompt_updates(&self, updates: Vec<acp::SessionUpdate>) {
507            *self.next_prompt_updates.lock() = updates;
508        }
509
510        pub fn with_permission_requests(
511            mut self,
512            permission_requests: HashMap<acp::ToolCallId, PermissionOptions>,
513        ) -> Self {
514            self.permission_requests = permission_requests;
515            self
516        }
517
518        pub fn send_update(
519            &self,
520            session_id: acp::SessionId,
521            update: acp::SessionUpdate,
522            cx: &mut App,
523        ) {
524            assert!(
525                self.next_prompt_updates.lock().is_empty(),
526                "Use either send_update or set_next_prompt_updates"
527            );
528
529            self.sessions
530                .lock()
531                .get(&session_id)
532                .unwrap()
533                .thread
534                .update(cx, |thread, cx| {
535                    thread.handle_session_update(update, cx).unwrap();
536                })
537                .unwrap();
538        }
539
540        pub fn end_turn(&self, session_id: acp::SessionId, stop_reason: acp::StopReason) {
541            self.sessions
542                .lock()
543                .get_mut(&session_id)
544                .unwrap()
545                .response_tx
546                .take()
547                .expect("No pending turn")
548                .send(stop_reason)
549                .unwrap();
550        }
551    }
552
553    impl AgentConnection for StubAgentConnection {
554        fn telemetry_id(&self) -> SharedString {
555            "stub".into()
556        }
557
558        fn auth_methods(&self) -> &[acp::AuthMethod] {
559            &[]
560        }
561
562        fn model_selector(
563            &self,
564            _session_id: &acp::SessionId,
565        ) -> Option<Rc<dyn AgentModelSelector>> {
566            Some(self.model_selector_impl())
567        }
568
569        fn new_thread(
570            self: Rc<Self>,
571            project: Entity<Project>,
572            _cwd: &Path,
573            cx: &mut gpui::App,
574        ) -> Task<gpui::Result<Entity<AcpThread>>> {
575            let session_id = acp::SessionId::new(self.sessions.lock().len().to_string());
576            let action_log = cx.new(|_| ActionLog::new(project.clone()));
577            let thread = cx.new(|cx| {
578                AcpThread::new(
579                    "Test",
580                    self.clone(),
581                    project,
582                    action_log,
583                    session_id.clone(),
584                    watch::Receiver::constant(
585                        acp::PromptCapabilities::new()
586                            .image(true)
587                            .audio(true)
588                            .embedded_context(true),
589                    ),
590                    cx,
591                )
592            });
593            self.sessions.lock().insert(
594                session_id,
595                Session {
596                    thread: thread.downgrade(),
597                    response_tx: None,
598                },
599            );
600            Task::ready(Ok(thread))
601        }
602
603        fn authenticate(
604            &self,
605            _method_id: acp::AuthMethodId,
606            _cx: &mut App,
607        ) -> Task<gpui::Result<()>> {
608            unimplemented!()
609        }
610
611        fn prompt(
612            &self,
613            _id: Option<UserMessageId>,
614            params: acp::PromptRequest,
615            cx: &mut App,
616        ) -> Task<gpui::Result<acp::PromptResponse>> {
617            let mut sessions = self.sessions.lock();
618            let Session {
619                thread,
620                response_tx,
621            } = sessions.get_mut(&params.session_id).unwrap();
622            let mut tasks = vec![];
623            if self.next_prompt_updates.lock().is_empty() {
624                let (tx, rx) = oneshot::channel();
625                response_tx.replace(tx);
626                cx.spawn(async move |_| {
627                    let stop_reason = rx.await?;
628                    Ok(acp::PromptResponse::new(stop_reason))
629                })
630            } else {
631                for update in self.next_prompt_updates.lock().drain(..) {
632                    let thread = thread.clone();
633                    let update = update.clone();
634                    let permission_request = if let acp::SessionUpdate::ToolCall(tool_call) =
635                        &update
636                        && let Some(options) = self.permission_requests.get(&tool_call.tool_call_id)
637                    {
638                        Some((tool_call.clone(), options.clone()))
639                    } else {
640                        None
641                    };
642                    let task = cx.spawn(async move |cx| {
643                        if let Some((tool_call, options)) = permission_request {
644                            thread
645                                .update(cx, |thread, cx| {
646                                    thread.request_tool_call_authorization(
647                                        tool_call.clone().into(),
648                                        options.clone(),
649                                        false,
650                                        cx,
651                                    )
652                                })??
653                                .await;
654                        }
655                        thread.update(cx, |thread, cx| {
656                            thread.handle_session_update(update.clone(), cx).unwrap();
657                        })?;
658                        anyhow::Ok(())
659                    });
660                    tasks.push(task);
661                }
662
663                cx.spawn(async move |_| {
664                    try_join_all(tasks).await?;
665                    Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
666                })
667            }
668        }
669
670        fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
671            if let Some(end_turn_tx) = self
672                .sessions
673                .lock()
674                .get_mut(session_id)
675                .unwrap()
676                .response_tx
677                .take()
678            {
679                end_turn_tx.send(acp::StopReason::Cancelled).unwrap();
680            }
681        }
682
683        fn truncate(
684            &self,
685            _session_id: &agent_client_protocol::SessionId,
686            _cx: &App,
687        ) -> Option<Rc<dyn AgentSessionTruncate>> {
688            Some(Rc::new(StubAgentSessionEditor))
689        }
690
691        fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
692            self
693        }
694    }
695
696    struct StubAgentSessionEditor;
697
698    impl AgentSessionTruncate for StubAgentSessionEditor {
699        fn run(&self, _: UserMessageId, _: &mut App) -> Task<Result<()>> {
700            Task::ready(Ok(()))
701        }
702    }
703
704    #[derive(Clone)]
705    struct StubModelSelector {
706        selected_model: Arc<Mutex<AgentModelInfo>>,
707    }
708
709    impl StubModelSelector {
710        fn new() -> Self {
711            Self {
712                selected_model: Arc::new(Mutex::new(AgentModelInfo {
713                    id: acp::ModelId::new("visual-test-model"),
714                    name: "Visual Test Model".into(),
715                    description: Some("A stub model for visual testing".into()),
716                    icon: Some(AgentModelIcon::Named(ui::IconName::ZedAssistant)),
717                })),
718            }
719        }
720    }
721
722    impl AgentModelSelector for StubModelSelector {
723        fn list_models(&self, _cx: &mut App) -> Task<Result<AgentModelList>> {
724            let model = self.selected_model.lock().clone();
725            Task::ready(Ok(AgentModelList::Flat(vec![model])))
726        }
727
728        fn select_model(&self, model_id: acp::ModelId, _cx: &mut App) -> Task<Result<()>> {
729            self.selected_model.lock().id = model_id;
730            Task::ready(Ok(()))
731        }
732
733        fn selected_model(&self, _cx: &mut App) -> Task<Result<AgentModelInfo>> {
734            Task::ready(Ok(self.selected_model.lock().clone()))
735        }
736    }
737
738    impl StubAgentConnection {
739        /// Returns a model selector for this stub connection.
740        pub fn model_selector_impl(&self) -> Rc<dyn AgentModelSelector> {
741            Rc::new(StubModelSelector::new())
742        }
743    }
744}
745
746#[cfg(feature = "test-support")]
747pub use test_support::*;