connection.rs

  1use crate::AcpThread;
  2use agent_client_protocol::{self as acp};
  3use anyhow::Result;
  4use collections::IndexMap;
  5use gpui::{Entity, SharedString, Task};
  6use language_model::LanguageModelProviderId;
  7use project::Project;
  8use serde::{Deserialize, Serialize};
  9use std::{any::Any, error::Error, fmt, path::Path, rc::Rc, sync::Arc};
 10use ui::{App, IconName};
 11use uuid::Uuid;
 12
 13#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
 14pub struct UserMessageId(Arc<str>);
 15
 16impl UserMessageId {
 17    pub fn new() -> Self {
 18        Self(Uuid::new_v4().to_string().into())
 19    }
 20}
 21
 22pub trait AgentConnection {
 23    fn new_thread(
 24        self: Rc<Self>,
 25        project: Entity<Project>,
 26        cwd: &Path,
 27        cx: &mut App,
 28    ) -> Task<Result<Entity<AcpThread>>>;
 29
 30    fn auth_methods(&self) -> &[acp::AuthMethod];
 31
 32    fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;
 33
 34    fn prompt(
 35        &self,
 36        user_message_id: Option<UserMessageId>,
 37        params: acp::PromptRequest,
 38        cx: &mut App,
 39    ) -> Task<Result<acp::PromptResponse>>;
 40
 41    fn resume(
 42        &self,
 43        _session_id: &acp::SessionId,
 44        _cx: &App,
 45    ) -> Option<Rc<dyn AgentSessionResume>> {
 46        None
 47    }
 48
 49    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
 50
 51    fn truncate(
 52        &self,
 53        _session_id: &acp::SessionId,
 54        _cx: &App,
 55    ) -> Option<Rc<dyn AgentSessionTruncate>> {
 56        None
 57    }
 58
 59    fn set_title(
 60        &self,
 61        _session_id: &acp::SessionId,
 62        _cx: &App,
 63    ) -> Option<Rc<dyn AgentSessionSetTitle>> {
 64        None
 65    }
 66
 67    /// Returns this agent as an [Rc<dyn ModelSelector>] if the model selection capability is supported.
 68    ///
 69    /// If the agent does not support model selection, returns [None].
 70    /// This allows sharing the selector in UI components.
 71    fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
 72        None
 73    }
 74
 75    fn telemetry(&self) -> Option<Rc<dyn AgentTelemetry>> {
 76        None
 77    }
 78    fn into_any(self: Rc<Self>) -> Rc<dyn Any>;
 79}
 80
 81impl dyn AgentConnection {
 82    pub fn downcast<T: 'static + AgentConnection + Sized>(self: Rc<Self>) -> Option<Rc<T>> {
 83        self.into_any().downcast().ok()
 84    }
 85}
 86
 87pub trait AgentSessionTruncate {
 88    fn run(&self, message_id: UserMessageId, cx: &mut App) -> Task<Result<()>>;
 89}
 90
 91pub trait AgentSessionResume {
 92    fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>>;
 93}
 94
 95pub trait AgentSessionSetTitle {
 96    fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>>;
 97}
 98
 99pub trait AgentTelemetry {
100    /// The name of the agent used for telemetry.
101    fn agent_name(&self) -> String;
102
103    /// A representation of the current thread state that can be serialized for
104    /// storage with telemetry events.
105    fn thread_data(
106        &self,
107        session_id: &acp::SessionId,
108        cx: &mut App,
109    ) -> Task<Result<serde_json::Value>>;
110}
111
112#[derive(Debug)]
113pub struct AuthRequired {
114    pub description: Option<String>,
115    pub provider_id: Option<LanguageModelProviderId>,
116}
117
118impl AuthRequired {
119    pub fn new() -> Self {
120        Self {
121            description: None,
122            provider_id: None,
123        }
124    }
125
126    pub fn with_description(mut self, description: String) -> Self {
127        self.description = Some(description);
128        self
129    }
130
131    pub fn with_language_model_provider(mut self, provider_id: LanguageModelProviderId) -> Self {
132        self.provider_id = Some(provider_id);
133        self
134    }
135}
136
137impl Error for AuthRequired {}
138impl fmt::Display for AuthRequired {
139    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
140        write!(f, "Authentication required")
141    }
142}
143
144/// Trait for agents that support listing, selecting, and querying language models.
145///
146/// This is an optional capability; agents indicate support via [AgentConnection::model_selector].
147pub trait AgentModelSelector: 'static {
148    /// Lists all available language models for this agent.
149    ///
150    /// # Parameters
151    /// - `cx`: The GPUI app context for async operations and global access.
152    ///
153    /// # Returns
154    /// A task resolving to the list of models or an error (e.g., if no models are configured).
155    fn list_models(&self, cx: &mut App) -> Task<Result<AgentModelList>>;
156
157    /// Selects a model for a specific session (thread).
158    ///
159    /// This sets the default model for future interactions in the session.
160    /// If the session doesn't exist or the model is invalid, it returns an error.
161    ///
162    /// # Parameters
163    /// - `session_id`: The ID of the session (thread) to apply the model to.
164    /// - `model`: The model to select (should be one from [list_models]).
165    /// - `cx`: The GPUI app context.
166    ///
167    /// # Returns
168    /// A task resolving to `Ok(())` on success or an error.
169    fn select_model(
170        &self,
171        session_id: acp::SessionId,
172        model_id: AgentModelId,
173        cx: &mut App,
174    ) -> Task<Result<()>>;
175
176    /// Retrieves the currently selected model for a specific session (thread).
177    ///
178    /// # Parameters
179    /// - `session_id`: The ID of the session (thread) to query.
180    /// - `cx`: The GPUI app context.
181    ///
182    /// # Returns
183    /// A task resolving to the selected model (always set) or an error (e.g., session not found).
184    fn selected_model(
185        &self,
186        session_id: &acp::SessionId,
187        cx: &mut App,
188    ) -> Task<Result<AgentModelInfo>>;
189
190    /// Whenever the model list is updated the receiver will be notified.
191    fn watch(&self, cx: &mut App) -> watch::Receiver<()>;
192}
193
194#[derive(Debug, Clone, PartialEq, Eq, Hash)]
195pub struct AgentModelId(pub SharedString);
196
197impl std::ops::Deref for AgentModelId {
198    type Target = SharedString;
199
200    fn deref(&self) -> &Self::Target {
201        &self.0
202    }
203}
204
205impl fmt::Display for AgentModelId {
206    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
207        self.0.fmt(f)
208    }
209}
210
211#[derive(Debug, Clone, PartialEq, Eq)]
212pub struct AgentModelInfo {
213    pub id: AgentModelId,
214    pub name: SharedString,
215    pub icon: Option<IconName>,
216}
217
218#[derive(Debug, Clone, PartialEq, Eq, Hash)]
219pub struct AgentModelGroupName(pub SharedString);
220
221#[derive(Debug, Clone)]
222pub enum AgentModelList {
223    Flat(Vec<AgentModelInfo>),
224    Grouped(IndexMap<AgentModelGroupName, Vec<AgentModelInfo>>),
225}
226
227impl AgentModelList {
228    pub fn is_empty(&self) -> bool {
229        match self {
230            AgentModelList::Flat(models) => models.is_empty(),
231            AgentModelList::Grouped(groups) => groups.is_empty(),
232        }
233    }
234}
235
236#[cfg(feature = "test-support")]
237mod test_support {
238    use std::sync::Arc;
239
240    use action_log::ActionLog;
241    use collections::HashMap;
242    use futures::{channel::oneshot, future::try_join_all};
243    use gpui::{AppContext as _, WeakEntity};
244    use parking_lot::Mutex;
245
246    use super::*;
247
248    #[derive(Clone, Default)]
249    pub struct StubAgentConnection {
250        sessions: Arc<Mutex<HashMap<acp::SessionId, Session>>>,
251        permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
252        next_prompt_updates: Arc<Mutex<Vec<acp::SessionUpdate>>>,
253    }
254
255    struct Session {
256        thread: WeakEntity<AcpThread>,
257        response_tx: Option<oneshot::Sender<acp::StopReason>>,
258    }
259
260    impl StubAgentConnection {
261        pub fn new() -> Self {
262            Self {
263                next_prompt_updates: Default::default(),
264                permission_requests: HashMap::default(),
265                sessions: Arc::default(),
266            }
267        }
268
269        pub fn set_next_prompt_updates(&self, updates: Vec<acp::SessionUpdate>) {
270            *self.next_prompt_updates.lock() = updates;
271        }
272
273        pub fn with_permission_requests(
274            mut self,
275            permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
276        ) -> Self {
277            self.permission_requests = permission_requests;
278            self
279        }
280
281        pub fn send_update(
282            &self,
283            session_id: acp::SessionId,
284            update: acp::SessionUpdate,
285            cx: &mut App,
286        ) {
287            assert!(
288                self.next_prompt_updates.lock().is_empty(),
289                "Use either send_update or set_next_prompt_updates"
290            );
291
292            self.sessions
293                .lock()
294                .get(&session_id)
295                .unwrap()
296                .thread
297                .update(cx, |thread, cx| {
298                    thread.handle_session_update(update, cx).unwrap();
299                })
300                .unwrap();
301        }
302
303        pub fn end_turn(&self, session_id: acp::SessionId, stop_reason: acp::StopReason) {
304            self.sessions
305                .lock()
306                .get_mut(&session_id)
307                .unwrap()
308                .response_tx
309                .take()
310                .expect("No pending turn")
311                .send(stop_reason)
312                .unwrap();
313        }
314    }
315
316    impl AgentConnection for StubAgentConnection {
317        fn auth_methods(&self) -> &[acp::AuthMethod] {
318            &[]
319        }
320
321        fn new_thread(
322            self: Rc<Self>,
323            project: Entity<Project>,
324            _cwd: &Path,
325            cx: &mut gpui::App,
326        ) -> Task<gpui::Result<Entity<AcpThread>>> {
327            let session_id = acp::SessionId(self.sessions.lock().len().to_string().into());
328            let action_log = cx.new(|_| ActionLog::new(project.clone()));
329            let thread = cx.new(|cx| {
330                AcpThread::new(
331                    "Test",
332                    self.clone(),
333                    project,
334                    action_log,
335                    session_id.clone(),
336                    watch::Receiver::constant(acp::PromptCapabilities {
337                        image: true,
338                        audio: true,
339                        embedded_context: true,
340                    }),
341                    cx,
342                )
343            });
344            self.sessions.lock().insert(
345                session_id,
346                Session {
347                    thread: thread.downgrade(),
348                    response_tx: None,
349                },
350            );
351            Task::ready(Ok(thread))
352        }
353
354        fn authenticate(
355            &self,
356            _method_id: acp::AuthMethodId,
357            _cx: &mut App,
358        ) -> Task<gpui::Result<()>> {
359            unimplemented!()
360        }
361
362        fn prompt(
363            &self,
364            _id: Option<UserMessageId>,
365            params: acp::PromptRequest,
366            cx: &mut App,
367        ) -> Task<gpui::Result<acp::PromptResponse>> {
368            let mut sessions = self.sessions.lock();
369            let Session {
370                thread,
371                response_tx,
372            } = sessions.get_mut(&params.session_id).unwrap();
373            let mut tasks = vec![];
374            if self.next_prompt_updates.lock().is_empty() {
375                let (tx, rx) = oneshot::channel();
376                response_tx.replace(tx);
377                cx.spawn(async move |_| {
378                    let stop_reason = rx.await?;
379                    Ok(acp::PromptResponse { stop_reason })
380                })
381            } else {
382                for update in self.next_prompt_updates.lock().drain(..) {
383                    let thread = thread.clone();
384                    let update = update.clone();
385                    let permission_request = if let acp::SessionUpdate::ToolCall(tool_call) =
386                        &update
387                        && let Some(options) = self.permission_requests.get(&tool_call.id)
388                    {
389                        Some((tool_call.clone(), options.clone()))
390                    } else {
391                        None
392                    };
393                    let task = cx.spawn(async move |cx| {
394                        if let Some((tool_call, options)) = permission_request {
395                            thread
396                                .update(cx, |thread, cx| {
397                                    thread.request_tool_call_authorization(
398                                        tool_call.clone().into(),
399                                        options.clone(),
400                                        cx,
401                                    )
402                                })??
403                                .await;
404                        }
405                        thread.update(cx, |thread, cx| {
406                            thread.handle_session_update(update.clone(), cx).unwrap();
407                        })?;
408                        anyhow::Ok(())
409                    });
410                    tasks.push(task);
411                }
412
413                cx.spawn(async move |_| {
414                    try_join_all(tasks).await?;
415                    Ok(acp::PromptResponse {
416                        stop_reason: acp::StopReason::EndTurn,
417                    })
418                })
419            }
420        }
421
422        fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
423            if let Some(end_turn_tx) = self
424                .sessions
425                .lock()
426                .get_mut(session_id)
427                .unwrap()
428                .response_tx
429                .take()
430            {
431                end_turn_tx.send(acp::StopReason::Cancelled).unwrap();
432            }
433        }
434
435        fn truncate(
436            &self,
437            _session_id: &agent_client_protocol::SessionId,
438            _cx: &App,
439        ) -> Option<Rc<dyn AgentSessionTruncate>> {
440            Some(Rc::new(StubAgentSessionEditor))
441        }
442
443        fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
444            self
445        }
446    }
447
448    struct StubAgentSessionEditor;
449
450    impl AgentSessionTruncate for StubAgentSessionEditor {
451        fn run(&self, _: UserMessageId, _: &mut App) -> Task<Result<()>> {
452            Task::ready(Ok(()))
453        }
454    }
455}
456
457#[cfg(feature = "test-support")]
458pub use test_support::*;