connection.rs

  1use crate::AcpThread;
  2use agent_client_protocol::{self as acp};
  3use anyhow::Result;
  4use chrono::{DateTime, Utc};
  5use collections::IndexMap;
  6use gpui::{Entity, SharedString, Task};
  7use language_model::LanguageModelProviderId;
  8use project::Project;
  9use serde::{Deserialize, Serialize};
 10use std::{
 11    any::Any,
 12    error::Error,
 13    fmt,
 14    path::{Path, PathBuf},
 15    rc::Rc,
 16    sync::Arc,
 17};
 18use ui::{App, IconName};
 19use uuid::Uuid;
 20
 21#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
 22pub struct UserMessageId(Arc<str>);
 23
 24impl UserMessageId {
 25    pub fn new() -> Self {
 26        Self(Uuid::new_v4().to_string().into())
 27    }
 28}
 29
 30pub trait AgentConnection {
 31    fn telemetry_id(&self) -> SharedString;
 32
 33    fn new_thread(
 34        self: Rc<Self>,
 35        project: Entity<Project>,
 36        cwd: &Path,
 37        cx: &mut App,
 38    ) -> Task<Result<Entity<AcpThread>>>;
 39
 40    /// Whether this agent supports loading existing sessions.
 41    fn supports_load_session(&self, _cx: &App) -> bool {
 42        false
 43    }
 44
 45    /// Load an existing session by ID.
 46    fn load_session(
 47        self: Rc<Self>,
 48        _session: AgentSessionInfo,
 49        _project: Entity<Project>,
 50        _cwd: &Path,
 51        _cx: &mut App,
 52    ) -> Task<Result<Entity<AcpThread>>> {
 53        Task::ready(Err(anyhow::Error::msg("Loading sessions is not supported")))
 54    }
 55
 56    fn auth_methods(&self) -> &[acp::AuthMethod];
 57
 58    fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;
 59
 60    fn prompt(
 61        &self,
 62        user_message_id: Option<UserMessageId>,
 63        params: acp::PromptRequest,
 64        cx: &mut App,
 65    ) -> Task<Result<acp::PromptResponse>>;
 66
 67    fn retry(&self, _session_id: &acp::SessionId, _cx: &App) -> Option<Rc<dyn AgentSessionRetry>> {
 68        None
 69    }
 70
 71    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
 72
 73    fn truncate(
 74        &self,
 75        _session_id: &acp::SessionId,
 76        _cx: &App,
 77    ) -> Option<Rc<dyn AgentSessionTruncate>> {
 78        None
 79    }
 80
 81    fn set_title(
 82        &self,
 83        _session_id: &acp::SessionId,
 84        _cx: &App,
 85    ) -> Option<Rc<dyn AgentSessionSetTitle>> {
 86        None
 87    }
 88
 89    /// Returns this agent as an [Rc<dyn ModelSelector>] if the model selection capability is supported.
 90    ///
 91    /// If the agent does not support model selection, returns [None].
 92    /// This allows sharing the selector in UI components.
 93    fn model_selector(&self, _session_id: &acp::SessionId) -> Option<Rc<dyn AgentModelSelector>> {
 94        None
 95    }
 96
 97    fn telemetry(&self) -> Option<Rc<dyn AgentTelemetry>> {
 98        None
 99    }
100
101    fn session_modes(
102        &self,
103        _session_id: &acp::SessionId,
104        _cx: &App,
105    ) -> Option<Rc<dyn AgentSessionModes>> {
106        None
107    }
108
109    fn session_config_options(
110        &self,
111        _session_id: &acp::SessionId,
112        _cx: &App,
113    ) -> Option<Rc<dyn AgentSessionConfigOptions>> {
114        None
115    }
116
117    fn session_list(&self, _cx: &mut App) -> Option<Rc<dyn AgentSessionList>> {
118        None
119    }
120
121    fn into_any(self: Rc<Self>) -> Rc<dyn Any>;
122}
123
124impl dyn AgentConnection {
125    pub fn downcast<T: 'static + AgentConnection + Sized>(self: Rc<Self>) -> Option<Rc<T>> {
126        self.into_any().downcast().ok()
127    }
128}
129
130pub trait AgentSessionTruncate {
131    fn run(&self, message_id: UserMessageId, cx: &mut App) -> Task<Result<()>>;
132}
133
134pub trait AgentSessionRetry {
135    fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>>;
136}
137
138pub trait AgentSessionSetTitle {
139    fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>>;
140}
141
142pub trait AgentTelemetry {
143    /// A representation of the current thread state that can be serialized for
144    /// storage with telemetry events.
145    fn thread_data(
146        &self,
147        session_id: &acp::SessionId,
148        cx: &mut App,
149    ) -> Task<Result<serde_json::Value>>;
150}
151
152pub trait AgentSessionModes {
153    fn current_mode(&self) -> acp::SessionModeId;
154
155    fn all_modes(&self) -> Vec<acp::SessionMode>;
156
157    fn set_mode(&self, mode: acp::SessionModeId, cx: &mut App) -> Task<Result<()>>;
158}
159
160pub trait AgentSessionConfigOptions {
161    /// Get all current config options with their state
162    fn config_options(&self) -> Vec<acp::SessionConfigOption>;
163
164    /// Set a config option value
165    /// Returns the full updated list of config options
166    fn set_config_option(
167        &self,
168        config_id: acp::SessionConfigId,
169        value: acp::SessionConfigValueId,
170        cx: &mut App,
171    ) -> Task<Result<Vec<acp::SessionConfigOption>>>;
172
173    /// Whenever the config options are updated the receiver will be notified.
174    /// Optional for agents that don't update their config options dynamically.
175    fn watch(&self, _cx: &mut App) -> Option<watch::Receiver<()>> {
176        None
177    }
178}
179
180#[derive(Debug, Clone, Default)]
181pub struct AgentSessionListRequest {
182    pub cwd: Option<PathBuf>,
183    pub cursor: Option<String>,
184    pub meta: Option<acp::Meta>,
185}
186
187#[derive(Debug, Clone)]
188pub struct AgentSessionListResponse {
189    pub sessions: Vec<AgentSessionInfo>,
190    pub next_cursor: Option<String>,
191    pub meta: Option<acp::Meta>,
192}
193
194impl AgentSessionListResponse {
195    pub fn new(sessions: Vec<AgentSessionInfo>) -> Self {
196        Self {
197            sessions,
198            next_cursor: None,
199            meta: None,
200        }
201    }
202}
203
204#[derive(Debug, Clone, PartialEq)]
205pub struct AgentSessionInfo {
206    pub session_id: acp::SessionId,
207    pub cwd: Option<PathBuf>,
208    pub title: Option<SharedString>,
209    pub updated_at: Option<DateTime<Utc>>,
210    pub meta: Option<acp::Meta>,
211}
212
213impl AgentSessionInfo {
214    pub fn new(session_id: impl Into<acp::SessionId>) -> Self {
215        Self {
216            session_id: session_id.into(),
217            cwd: None,
218            title: None,
219            updated_at: None,
220            meta: None,
221        }
222    }
223}
224
225#[derive(Debug, Clone)]
226pub enum SessionListUpdate {
227    Refresh,
228    SessionInfo {
229        session_id: acp::SessionId,
230        update: acp::SessionInfoUpdate,
231    },
232}
233
234pub trait AgentSessionList {
235    fn list_sessions(
236        &self,
237        request: AgentSessionListRequest,
238        cx: &mut App,
239    ) -> Task<Result<AgentSessionListResponse>>;
240
241    fn supports_delete(&self) -> bool {
242        false
243    }
244
245    fn delete_session(&self, _session_id: &acp::SessionId, _cx: &mut App) -> Task<Result<()>> {
246        Task::ready(Err(anyhow::anyhow!("delete_session not supported")))
247    }
248
249    fn delete_sessions(&self, _cx: &mut App) -> Task<Result<()>> {
250        Task::ready(Err(anyhow::anyhow!("delete_sessions not supported")))
251    }
252
253    fn watch(&self, _cx: &mut App) -> Option<smol::channel::Receiver<SessionListUpdate>> {
254        None
255    }
256
257    fn into_any(self: Rc<Self>) -> Rc<dyn Any>;
258}
259
260impl dyn AgentSessionList {
261    pub fn downcast<T: 'static + AgentSessionList + Sized>(self: Rc<Self>) -> Option<Rc<T>> {
262        self.into_any().downcast().ok()
263    }
264}
265
266#[derive(Debug)]
267pub struct AuthRequired {
268    pub description: Option<String>,
269    pub provider_id: Option<LanguageModelProviderId>,
270}
271
272impl AuthRequired {
273    pub fn new() -> Self {
274        Self {
275            description: None,
276            provider_id: None,
277        }
278    }
279
280    pub fn with_description(mut self, description: String) -> Self {
281        self.description = Some(description);
282        self
283    }
284
285    pub fn with_language_model_provider(mut self, provider_id: LanguageModelProviderId) -> Self {
286        self.provider_id = Some(provider_id);
287        self
288    }
289}
290
291impl Error for AuthRequired {}
292impl fmt::Display for AuthRequired {
293    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
294        write!(f, "Authentication required")
295    }
296}
297
298/// Trait for agents that support listing, selecting, and querying language models.
299///
300/// This is an optional capability; agents indicate support via [AgentConnection::model_selector].
301pub trait AgentModelSelector: 'static {
302    /// Lists all available language models for this agent.
303    ///
304    /// # Parameters
305    /// - `cx`: The GPUI app context for async operations and global access.
306    ///
307    /// # Returns
308    /// A task resolving to the list of models or an error (e.g., if no models are configured).
309    fn list_models(&self, cx: &mut App) -> Task<Result<AgentModelList>>;
310
311    /// Selects a model for a specific session (thread).
312    ///
313    /// This sets the default model for future interactions in the session.
314    /// If the session doesn't exist or the model is invalid, it returns an error.
315    ///
316    /// # Parameters
317    /// - `model`: The model to select (should be one from [list_models]).
318    /// - `cx`: The GPUI app context.
319    ///
320    /// # Returns
321    /// A task resolving to `Ok(())` on success or an error.
322    fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task<Result<()>>;
323
324    /// Retrieves the currently selected model for a specific session (thread).
325    ///
326    /// # Parameters
327    /// - `cx`: The GPUI app context.
328    ///
329    /// # Returns
330    /// A task resolving to the selected model (always set) or an error (e.g., session not found).
331    fn selected_model(&self, cx: &mut App) -> Task<Result<AgentModelInfo>>;
332
333    /// Whenever the model list is updated the receiver will be notified.
334    /// Optional for agents that don't update their model list.
335    fn watch(&self, _cx: &mut App) -> Option<watch::Receiver<()>> {
336        None
337    }
338
339    /// Returns whether the model picker should render a footer.
340    fn should_render_footer(&self) -> bool {
341        false
342    }
343}
344
345/// Icon for a model in the model selector.
346#[derive(Debug, Clone, PartialEq, Eq)]
347pub enum AgentModelIcon {
348    /// A built-in icon from Zed's icon set.
349    Named(IconName),
350    /// Path to a custom SVG icon file.
351    Path(SharedString),
352}
353
354#[derive(Debug, Clone, PartialEq, Eq)]
355pub struct AgentModelInfo {
356    pub id: acp::ModelId,
357    pub name: SharedString,
358    pub description: Option<SharedString>,
359    pub icon: Option<AgentModelIcon>,
360}
361
362impl From<acp::ModelInfo> for AgentModelInfo {
363    fn from(info: acp::ModelInfo) -> Self {
364        Self {
365            id: info.model_id,
366            name: info.name.into(),
367            description: info.description.map(|desc| desc.into()),
368            icon: None,
369        }
370    }
371}
372
373#[derive(Debug, Clone, PartialEq, Eq, Hash)]
374pub struct AgentModelGroupName(pub SharedString);
375
376#[derive(Debug, Clone)]
377pub enum AgentModelList {
378    Flat(Vec<AgentModelInfo>),
379    Grouped(IndexMap<AgentModelGroupName, Vec<AgentModelInfo>>),
380}
381
382impl AgentModelList {
383    pub fn is_empty(&self) -> bool {
384        match self {
385            AgentModelList::Flat(models) => models.is_empty(),
386            AgentModelList::Grouped(groups) => groups.is_empty(),
387        }
388    }
389
390    pub fn is_flat(&self) -> bool {
391        matches!(self, AgentModelList::Flat(_))
392    }
393}
394
395#[derive(Debug, Clone)]
396pub struct PermissionOptionChoice {
397    pub allow: acp::PermissionOption,
398    pub deny: acp::PermissionOption,
399}
400
401impl PermissionOptionChoice {
402    pub fn label(&self) -> SharedString {
403        self.allow.name.clone().into()
404    }
405}
406
407#[derive(Debug, Clone)]
408pub enum PermissionOptions {
409    Flat(Vec<acp::PermissionOption>),
410    Dropdown(Vec<PermissionOptionChoice>),
411}
412
413impl PermissionOptions {
414    pub fn is_empty(&self) -> bool {
415        match self {
416            PermissionOptions::Flat(options) => options.is_empty(),
417            PermissionOptions::Dropdown(options) => options.is_empty(),
418        }
419    }
420
421    pub fn first_option_of_kind(
422        &self,
423        kind: acp::PermissionOptionKind,
424    ) -> Option<&acp::PermissionOption> {
425        match self {
426            PermissionOptions::Flat(options) => options.iter().find(|option| option.kind == kind),
427            PermissionOptions::Dropdown(options) => options.iter().find_map(|choice| {
428                if choice.allow.kind == kind {
429                    Some(&choice.allow)
430                } else if choice.deny.kind == kind {
431                    Some(&choice.deny)
432                } else {
433                    None
434                }
435            }),
436        }
437    }
438
439    pub fn allow_once_option_id(&self) -> Option<acp::PermissionOptionId> {
440        self.first_option_of_kind(acp::PermissionOptionKind::AllowOnce)
441            .map(|option| option.option_id.clone())
442    }
443}
444
445#[cfg(feature = "test-support")]
446mod test_support {
447    //! Test-only stubs and helpers for acp_thread.
448    //!
449    //! This module is gated by the `test-support` feature and is not included
450    //! in production builds. It provides:
451    //! - `StubAgentConnection` for mocking agent connections in tests
452    //! - `create_test_png_base64` for generating test images
453
454    use std::sync::Arc;
455
456    use action_log::ActionLog;
457    use collections::HashMap;
458    use futures::{channel::oneshot, future::try_join_all};
459    use gpui::{AppContext as _, WeakEntity};
460    use parking_lot::Mutex;
461
462    use super::*;
463
464    /// Creates a PNG image encoded as base64 for testing.
465    ///
466    /// Generates a solid-color PNG of the specified dimensions and returns
467    /// it as a base64-encoded string suitable for use in `ImageContent`.
468    pub fn create_test_png_base64(width: u32, height: u32, color: [u8; 4]) -> String {
469        use image::ImageEncoder as _;
470
471        let mut png_data = Vec::new();
472        {
473            let encoder = image::codecs::png::PngEncoder::new(&mut png_data);
474            let mut pixels = Vec::with_capacity((width * height * 4) as usize);
475            for _ in 0..(width * height) {
476                pixels.extend_from_slice(&color);
477            }
478            encoder
479                .write_image(&pixels, width, height, image::ExtendedColorType::Rgba8)
480                .expect("Failed to encode PNG");
481        }
482
483        use image::EncodableLayout as _;
484        base64::Engine::encode(
485            &base64::engine::general_purpose::STANDARD,
486            png_data.as_bytes(),
487        )
488    }
489
490    #[derive(Clone, Default)]
491    pub struct StubAgentConnection {
492        sessions: Arc<Mutex<HashMap<acp::SessionId, Session>>>,
493        permission_requests: HashMap<acp::ToolCallId, PermissionOptions>,
494        next_prompt_updates: Arc<Mutex<Vec<acp::SessionUpdate>>>,
495    }
496
497    struct Session {
498        thread: WeakEntity<AcpThread>,
499        response_tx: Option<oneshot::Sender<acp::StopReason>>,
500    }
501
502    impl StubAgentConnection {
503        pub fn new() -> Self {
504            Self {
505                next_prompt_updates: Default::default(),
506                permission_requests: HashMap::default(),
507                sessions: Arc::default(),
508            }
509        }
510
511        pub fn set_next_prompt_updates(&self, updates: Vec<acp::SessionUpdate>) {
512            *self.next_prompt_updates.lock() = updates;
513        }
514
515        pub fn with_permission_requests(
516            mut self,
517            permission_requests: HashMap<acp::ToolCallId, PermissionOptions>,
518        ) -> Self {
519            self.permission_requests = permission_requests;
520            self
521        }
522
523        pub fn send_update(
524            &self,
525            session_id: acp::SessionId,
526            update: acp::SessionUpdate,
527            cx: &mut App,
528        ) {
529            assert!(
530                self.next_prompt_updates.lock().is_empty(),
531                "Use either send_update or set_next_prompt_updates"
532            );
533
534            self.sessions
535                .lock()
536                .get(&session_id)
537                .unwrap()
538                .thread
539                .update(cx, |thread, cx| {
540                    thread.handle_session_update(update, cx).unwrap();
541                })
542                .unwrap();
543        }
544
545        pub fn end_turn(&self, session_id: acp::SessionId, stop_reason: acp::StopReason) {
546            self.sessions
547                .lock()
548                .get_mut(&session_id)
549                .unwrap()
550                .response_tx
551                .take()
552                .expect("No pending turn")
553                .send(stop_reason)
554                .unwrap();
555        }
556    }
557
558    impl AgentConnection for StubAgentConnection {
559        fn telemetry_id(&self) -> SharedString {
560            "stub".into()
561        }
562
563        fn auth_methods(&self) -> &[acp::AuthMethod] {
564            &[]
565        }
566
567        fn model_selector(
568            &self,
569            _session_id: &acp::SessionId,
570        ) -> Option<Rc<dyn AgentModelSelector>> {
571            Some(self.model_selector_impl())
572        }
573
574        fn new_thread(
575            self: Rc<Self>,
576            project: Entity<Project>,
577            _cwd: &Path,
578            cx: &mut gpui::App,
579        ) -> Task<gpui::Result<Entity<AcpThread>>> {
580            let session_id = acp::SessionId::new(self.sessions.lock().len().to_string());
581            let action_log = cx.new(|_| ActionLog::new(project.clone()));
582            let thread = cx.new(|cx| {
583                AcpThread::new(
584                    "Test",
585                    self.clone(),
586                    project,
587                    action_log,
588                    session_id.clone(),
589                    watch::Receiver::constant(
590                        acp::PromptCapabilities::new()
591                            .image(true)
592                            .audio(true)
593                            .embedded_context(true),
594                    ),
595                    cx,
596                )
597            });
598            self.sessions.lock().insert(
599                session_id,
600                Session {
601                    thread: thread.downgrade(),
602                    response_tx: None,
603                },
604            );
605            Task::ready(Ok(thread))
606        }
607
608        fn authenticate(
609            &self,
610            _method_id: acp::AuthMethodId,
611            _cx: &mut App,
612        ) -> Task<gpui::Result<()>> {
613            unimplemented!()
614        }
615
616        fn prompt(
617            &self,
618            _id: Option<UserMessageId>,
619            params: acp::PromptRequest,
620            cx: &mut App,
621        ) -> Task<gpui::Result<acp::PromptResponse>> {
622            let mut sessions = self.sessions.lock();
623            let Session {
624                thread,
625                response_tx,
626            } = sessions.get_mut(&params.session_id).unwrap();
627            let mut tasks = vec![];
628            if self.next_prompt_updates.lock().is_empty() {
629                let (tx, rx) = oneshot::channel();
630                response_tx.replace(tx);
631                cx.spawn(async move |_| {
632                    let stop_reason = rx.await?;
633                    Ok(acp::PromptResponse::new(stop_reason))
634                })
635            } else {
636                for update in self.next_prompt_updates.lock().drain(..) {
637                    let thread = thread.clone();
638                    let update = update.clone();
639                    let permission_request = if let acp::SessionUpdate::ToolCall(tool_call) =
640                        &update
641                        && let Some(options) = self.permission_requests.get(&tool_call.tool_call_id)
642                    {
643                        Some((tool_call.clone(), options.clone()))
644                    } else {
645                        None
646                    };
647                    let task = cx.spawn(async move |cx| {
648                        if let Some((tool_call, options)) = permission_request {
649                            thread
650                                .update(cx, |thread, cx| {
651                                    thread.request_tool_call_authorization(
652                                        tool_call.clone().into(),
653                                        options.clone(),
654                                        false,
655                                        cx,
656                                    )
657                                })??
658                                .await;
659                        }
660                        thread.update(cx, |thread, cx| {
661                            thread.handle_session_update(update.clone(), cx).unwrap();
662                        })?;
663                        anyhow::Ok(())
664                    });
665                    tasks.push(task);
666                }
667
668                cx.spawn(async move |_| {
669                    try_join_all(tasks).await?;
670                    Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
671                })
672            }
673        }
674
675        fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
676            if let Some(end_turn_tx) = self
677                .sessions
678                .lock()
679                .get_mut(session_id)
680                .unwrap()
681                .response_tx
682                .take()
683            {
684                end_turn_tx.send(acp::StopReason::Cancelled).unwrap();
685            }
686        }
687
688        fn truncate(
689            &self,
690            _session_id: &agent_client_protocol::SessionId,
691            _cx: &App,
692        ) -> Option<Rc<dyn AgentSessionTruncate>> {
693            Some(Rc::new(StubAgentSessionEditor))
694        }
695
696        fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
697            self
698        }
699    }
700
701    struct StubAgentSessionEditor;
702
703    impl AgentSessionTruncate for StubAgentSessionEditor {
704        fn run(&self, _: UserMessageId, _: &mut App) -> Task<Result<()>> {
705            Task::ready(Ok(()))
706        }
707    }
708
709    #[derive(Clone)]
710    struct StubModelSelector {
711        selected_model: Arc<Mutex<AgentModelInfo>>,
712    }
713
714    impl StubModelSelector {
715        fn new() -> Self {
716            Self {
717                selected_model: Arc::new(Mutex::new(AgentModelInfo {
718                    id: acp::ModelId::new("visual-test-model"),
719                    name: "Visual Test Model".into(),
720                    description: Some("A stub model for visual testing".into()),
721                    icon: Some(AgentModelIcon::Named(ui::IconName::ZedAssistant)),
722                })),
723            }
724        }
725    }
726
727    impl AgentModelSelector for StubModelSelector {
728        fn list_models(&self, _cx: &mut App) -> Task<Result<AgentModelList>> {
729            let model = self.selected_model.lock().clone();
730            Task::ready(Ok(AgentModelList::Flat(vec![model])))
731        }
732
733        fn select_model(&self, model_id: acp::ModelId, _cx: &mut App) -> Task<Result<()>> {
734            self.selected_model.lock().id = model_id;
735            Task::ready(Ok(()))
736        }
737
738        fn selected_model(&self, _cx: &mut App) -> Task<Result<AgentModelInfo>> {
739            Task::ready(Ok(self.selected_model.lock().clone()))
740        }
741    }
742
743    impl StubAgentConnection {
744        /// Returns a model selector for this stub connection.
745        pub fn model_selector_impl(&self) -> Rc<dyn AgentModelSelector> {
746            Rc::new(StubModelSelector::new())
747        }
748    }
749}
750
751#[cfg(feature = "test-support")]
752pub use test_support::*;