connection.rs

  1use crate::AcpThread;
  2use agent_client_protocol::{self as acp};
  3use anyhow::Result;
  4use collections::IndexMap;
  5use gpui::{Entity, SharedString, Task};
  6use language_model::LanguageModelProviderId;
  7use project::Project;
  8use serde::{Deserialize, Serialize};
  9use std::{any::Any, error::Error, fmt, path::Path, rc::Rc, sync::Arc};
 10use ui::{App, IconName};
 11use uuid::Uuid;
 12
 13#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
 14pub struct UserMessageId(Arc<str>);
 15
 16impl UserMessageId {
 17    pub fn new() -> Self {
 18        Self(Uuid::new_v4().to_string().into())
 19    }
 20}
 21
 22pub trait AgentConnection {
 23    fn new_thread(
 24        self: Rc<Self>,
 25        project: Entity<Project>,
 26        cwd: &Path,
 27        cx: &mut App,
 28    ) -> Task<Result<Entity<AcpThread>>>;
 29
 30    fn auth_methods(&self) -> &[acp::AuthMethod];
 31
 32    fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;
 33
 34    fn prompt(
 35        &self,
 36        user_message_id: Option<UserMessageId>,
 37        params: acp::PromptRequest,
 38        cx: &mut App,
 39    ) -> Task<Result<acp::PromptResponse>>;
 40
 41    fn resume(
 42        &self,
 43        _session_id: &acp::SessionId,
 44        _cx: &App,
 45    ) -> Option<Rc<dyn AgentSessionResume>> {
 46        None
 47    }
 48
 49    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
 50
 51    fn truncate(
 52        &self,
 53        _session_id: &acp::SessionId,
 54        _cx: &App,
 55    ) -> Option<Rc<dyn AgentSessionTruncate>> {
 56        None
 57    }
 58
 59    fn set_title(
 60        &self,
 61        _session_id: &acp::SessionId,
 62        _cx: &App,
 63    ) -> Option<Rc<dyn AgentSessionSetTitle>> {
 64        None
 65    }
 66
 67    /// Returns this agent as an [Rc<dyn ModelSelector>] if the model selection capability is supported.
 68    ///
 69    /// If the agent does not support model selection, returns [None].
 70    /// This allows sharing the selector in UI components.
 71    fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
 72        None
 73    }
 74
 75    fn telemetry(&self) -> Option<Rc<dyn AgentTelemetry>> {
 76        None
 77    }
 78
 79    fn session_modes(
 80        &self,
 81        _session_id: &acp::SessionId,
 82        _cx: &App,
 83    ) -> Option<Rc<dyn AgentSessionModes>> {
 84        None
 85    }
 86
 87    fn into_any(self: Rc<Self>) -> Rc<dyn Any>;
 88}
 89
 90impl dyn AgentConnection {
 91    pub fn downcast<T: 'static + AgentConnection + Sized>(self: Rc<Self>) -> Option<Rc<T>> {
 92        self.into_any().downcast().ok()
 93    }
 94}
 95
 96pub trait AgentSessionTruncate {
 97    fn run(&self, message_id: UserMessageId, cx: &mut App) -> Task<Result<()>>;
 98}
 99
100pub trait AgentSessionResume {
101    fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>>;
102}
103
104pub trait AgentSessionSetTitle {
105    fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>>;
106}
107
108pub trait AgentTelemetry {
109    /// The name of the agent used for telemetry.
110    fn agent_name(&self) -> String;
111
112    /// A representation of the current thread state that can be serialized for
113    /// storage with telemetry events.
114    fn thread_data(
115        &self,
116        session_id: &acp::SessionId,
117        cx: &mut App,
118    ) -> Task<Result<serde_json::Value>>;
119}
120
121pub trait AgentSessionModes {
122    fn current_mode(&self) -> acp::SessionModeId;
123
124    fn all_modes(&self) -> Vec<acp::SessionMode>;
125
126    fn set_mode(&self, mode: acp::SessionModeId, cx: &mut App) -> Task<Result<()>>;
127}
128
129#[derive(Debug)]
130pub struct AuthRequired {
131    pub description: Option<String>,
132    pub provider_id: Option<LanguageModelProviderId>,
133}
134
135impl AuthRequired {
136    pub fn new() -> Self {
137        Self {
138            description: None,
139            provider_id: None,
140        }
141    }
142
143    pub fn with_description(mut self, description: String) -> Self {
144        self.description = Some(description);
145        self
146    }
147
148    pub fn with_language_model_provider(mut self, provider_id: LanguageModelProviderId) -> Self {
149        self.provider_id = Some(provider_id);
150        self
151    }
152}
153
154impl Error for AuthRequired {}
155impl fmt::Display for AuthRequired {
156    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
157        write!(f, "Authentication required")
158    }
159}
160
161/// Trait for agents that support listing, selecting, and querying language models.
162///
163/// This is an optional capability; agents indicate support via [AgentConnection::model_selector].
164pub trait AgentModelSelector: 'static {
165    /// Lists all available language models for this agent.
166    ///
167    /// # Parameters
168    /// - `cx`: The GPUI app context for async operations and global access.
169    ///
170    /// # Returns
171    /// A task resolving to the list of models or an error (e.g., if no models are configured).
172    fn list_models(&self, cx: &mut App) -> Task<Result<AgentModelList>>;
173
174    /// Selects a model for a specific session (thread).
175    ///
176    /// This sets the default model for future interactions in the session.
177    /// If the session doesn't exist or the model is invalid, it returns an error.
178    ///
179    /// # Parameters
180    /// - `session_id`: The ID of the session (thread) to apply the model to.
181    /// - `model`: The model to select (should be one from [list_models]).
182    /// - `cx`: The GPUI app context.
183    ///
184    /// # Returns
185    /// A task resolving to `Ok(())` on success or an error.
186    fn select_model(
187        &self,
188        session_id: acp::SessionId,
189        model_id: AgentModelId,
190        cx: &mut App,
191    ) -> Task<Result<()>>;
192
193    /// Retrieves the currently selected model for a specific session (thread).
194    ///
195    /// # Parameters
196    /// - `session_id`: The ID of the session (thread) to query.
197    /// - `cx`: The GPUI app context.
198    ///
199    /// # Returns
200    /// A task resolving to the selected model (always set) or an error (e.g., session not found).
201    fn selected_model(
202        &self,
203        session_id: &acp::SessionId,
204        cx: &mut App,
205    ) -> Task<Result<AgentModelInfo>>;
206
207    /// Whenever the model list is updated the receiver will be notified.
208    fn watch(&self, cx: &mut App) -> watch::Receiver<()>;
209}
210
211#[derive(Debug, Clone, PartialEq, Eq, Hash)]
212pub struct AgentModelId(pub SharedString);
213
214impl std::ops::Deref for AgentModelId {
215    type Target = SharedString;
216
217    fn deref(&self) -> &Self::Target {
218        &self.0
219    }
220}
221
/// Formats the id by delegating to the wrapped `SharedString`.
impl fmt::Display for AgentModelId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Forward the formatter (including any flags) to the inner string.
        self.0.fmt(f)
    }
}
227
/// Information about a single language model offered by an agent.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AgentModelInfo {
    /// Identifier to pass to [`AgentModelSelector::select_model`].
    pub id: AgentModelId,
    /// Human-readable model name (presumably what the UI shows; confirm at call sites).
    pub name: SharedString,
    /// Optional icon associated with the model.
    pub icon: Option<IconName>,
}

/// Name of a group in [`AgentModelList::Grouped`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct AgentModelGroupName(pub SharedString);

/// Models returned by [`AgentModelSelector::list_models`].
#[derive(Debug, Clone)]
pub enum AgentModelList {
    /// An ungrouped list of models.
    Flat(Vec<AgentModelInfo>),
    /// Models grouped under named headings, stored in an `IndexMap`
    /// (NOTE(review): presumably insertion-ordered; confirm the `collections` alias).
    Grouped(IndexMap<AgentModelGroupName, Vec<AgentModelInfo>>),
}
243
244impl AgentModelList {
245    pub fn is_empty(&self) -> bool {
246        match self {
247            AgentModelList::Flat(models) => models.is_empty(),
248            AgentModelList::Grouped(groups) => groups.is_empty(),
249        }
250    }
251}
252
253#[cfg(feature = "test-support")]
254mod test_support {
255    use std::sync::Arc;
256
257    use action_log::ActionLog;
258    use collections::HashMap;
259    use futures::{channel::oneshot, future::try_join_all};
260    use gpui::{AppContext as _, WeakEntity};
261    use parking_lot::Mutex;
262
263    use super::*;
264
265    #[derive(Clone, Default)]
266    pub struct StubAgentConnection {
267        sessions: Arc<Mutex<HashMap<acp::SessionId, Session>>>,
268        permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
269        next_prompt_updates: Arc<Mutex<Vec<acp::SessionUpdate>>>,
270    }
271
272    struct Session {
273        thread: WeakEntity<AcpThread>,
274        response_tx: Option<oneshot::Sender<acp::StopReason>>,
275    }
276
277    impl StubAgentConnection {
278        pub fn new() -> Self {
279            Self {
280                next_prompt_updates: Default::default(),
281                permission_requests: HashMap::default(),
282                sessions: Arc::default(),
283            }
284        }
285
286        pub fn set_next_prompt_updates(&self, updates: Vec<acp::SessionUpdate>) {
287            *self.next_prompt_updates.lock() = updates;
288        }
289
290        pub fn with_permission_requests(
291            mut self,
292            permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
293        ) -> Self {
294            self.permission_requests = permission_requests;
295            self
296        }
297
298        pub fn send_update(
299            &self,
300            session_id: acp::SessionId,
301            update: acp::SessionUpdate,
302            cx: &mut App,
303        ) {
304            assert!(
305                self.next_prompt_updates.lock().is_empty(),
306                "Use either send_update or set_next_prompt_updates"
307            );
308
309            self.sessions
310                .lock()
311                .get(&session_id)
312                .unwrap()
313                .thread
314                .update(cx, |thread, cx| {
315                    thread.handle_session_update(update, cx).unwrap();
316                })
317                .unwrap();
318        }
319
320        pub fn end_turn(&self, session_id: acp::SessionId, stop_reason: acp::StopReason) {
321            self.sessions
322                .lock()
323                .get_mut(&session_id)
324                .unwrap()
325                .response_tx
326                .take()
327                .expect("No pending turn")
328                .send(stop_reason)
329                .unwrap();
330        }
331    }
332
333    impl AgentConnection for StubAgentConnection {
334        fn auth_methods(&self) -> &[acp::AuthMethod] {
335            &[]
336        }
337
338        fn new_thread(
339            self: Rc<Self>,
340            project: Entity<Project>,
341            _cwd: &Path,
342            cx: &mut gpui::App,
343        ) -> Task<gpui::Result<Entity<AcpThread>>> {
344            let session_id = acp::SessionId(self.sessions.lock().len().to_string().into());
345            let action_log = cx.new(|_| ActionLog::new(project.clone()));
346            let thread = cx.new(|cx| {
347                AcpThread::new(
348                    "Test",
349                    self.clone(),
350                    project,
351                    action_log,
352                    session_id.clone(),
353                    watch::Receiver::constant(acp::PromptCapabilities {
354                        image: true,
355                        audio: true,
356                        embedded_context: true,
357                        meta: None,
358                    }),
359                    cx,
360                )
361            });
362            self.sessions.lock().insert(
363                session_id,
364                Session {
365                    thread: thread.downgrade(),
366                    response_tx: None,
367                },
368            );
369            Task::ready(Ok(thread))
370        }
371
372        fn authenticate(
373            &self,
374            _method_id: acp::AuthMethodId,
375            _cx: &mut App,
376        ) -> Task<gpui::Result<()>> {
377            unimplemented!()
378        }
379
380        fn prompt(
381            &self,
382            _id: Option<UserMessageId>,
383            params: acp::PromptRequest,
384            cx: &mut App,
385        ) -> Task<gpui::Result<acp::PromptResponse>> {
386            let mut sessions = self.sessions.lock();
387            let Session {
388                thread,
389                response_tx,
390            } = sessions.get_mut(&params.session_id).unwrap();
391            let mut tasks = vec![];
392            if self.next_prompt_updates.lock().is_empty() {
393                let (tx, rx) = oneshot::channel();
394                response_tx.replace(tx);
395                cx.spawn(async move |_| {
396                    let stop_reason = rx.await?;
397                    Ok(acp::PromptResponse {
398                        stop_reason,
399                        meta: None,
400                    })
401                })
402            } else {
403                for update in self.next_prompt_updates.lock().drain(..) {
404                    let thread = thread.clone();
405                    let update = update.clone();
406                    let permission_request = if let acp::SessionUpdate::ToolCall(tool_call) =
407                        &update
408                        && let Some(options) = self.permission_requests.get(&tool_call.id)
409                    {
410                        Some((tool_call.clone(), options.clone()))
411                    } else {
412                        None
413                    };
414                    let task = cx.spawn(async move |cx| {
415                        if let Some((tool_call, options)) = permission_request {
416                            thread
417                                .update(cx, |thread, cx| {
418                                    thread.request_tool_call_authorization(
419                                        tool_call.clone().into(),
420                                        options.clone(),
421                                        false,
422                                        cx,
423                                    )
424                                })??
425                                .await;
426                        }
427                        thread.update(cx, |thread, cx| {
428                            thread.handle_session_update(update.clone(), cx).unwrap();
429                        })?;
430                        anyhow::Ok(())
431                    });
432                    tasks.push(task);
433                }
434
435                cx.spawn(async move |_| {
436                    try_join_all(tasks).await?;
437                    Ok(acp::PromptResponse {
438                        stop_reason: acp::StopReason::EndTurn,
439                        meta: None,
440                    })
441                })
442            }
443        }
444
445        fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
446            if let Some(end_turn_tx) = self
447                .sessions
448                .lock()
449                .get_mut(session_id)
450                .unwrap()
451                .response_tx
452                .take()
453            {
454                end_turn_tx.send(acp::StopReason::Cancelled).unwrap();
455            }
456        }
457
458        fn truncate(
459            &self,
460            _session_id: &agent_client_protocol::SessionId,
461            _cx: &App,
462        ) -> Option<Rc<dyn AgentSessionTruncate>> {
463            Some(Rc::new(StubAgentSessionEditor))
464        }
465
466        fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
467            self
468        }
469    }
470
471    struct StubAgentSessionEditor;
472
473    impl AgentSessionTruncate for StubAgentSessionEditor {
474        fn run(&self, _: UserMessageId, _: &mut App) -> Task<Result<()>> {
475            Task::ready(Ok(()))
476        }
477    }
478}
479
480#[cfg(feature = "test-support")]
481pub use test_support::*;