connection.rs

  1use crate::{AcpThread, AcpThreadMetadata};
  2use agent_client_protocol::{self as acp};
  3use anyhow::Result;
  4use collections::IndexMap;
  5use gpui::{Entity, SharedString, Task};
  6use project::Project;
  7use serde::{Deserialize, Serialize};
  8use std::{any::Any, error::Error, fmt, path::Path, rc::Rc, sync::Arc};
  9use ui::{App, IconName};
 10use uuid::Uuid;
 11
 12#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
 13pub struct UserMessageId(Arc<str>);
 14
 15impl UserMessageId {
 16    pub fn new() -> Self {
 17        Self(Uuid::new_v4().to_string().into())
 18    }
 19}
 20
 21pub trait AgentConnection {
 22    fn new_thread(
 23        self: Rc<Self>,
 24        project: Entity<Project>,
 25        cwd: &Path,
 26        cx: &mut App,
 27    ) -> Task<Result<Entity<AcpThread>>>;
 28
 29    // todo!(expose a history trait, and include list_threads and load_thread)
 30    // todo!(write a test)
 31    fn list_threads(
 32        &self,
 33        _cx: &mut App,
 34    ) -> Option<watch::Receiver<Option<Vec<AcpThreadMetadata>>>> {
 35        return None;
 36    }
 37
 38    fn load_thread(
 39        self: Rc<Self>,
 40        _project: Entity<Project>,
 41        _cwd: &Path,
 42        _session_id: acp::SessionId,
 43        _cx: &mut App,
 44    ) -> Task<Result<Entity<AcpThread>>> {
 45        Task::ready(Err(anyhow::anyhow!("load thread not implemented")))
 46    }
 47
 48    fn auth_methods(&self) -> &[acp::AuthMethod];
 49
 50    fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;
 51
 52    fn prompt(
 53        &self,
 54        user_message_id: Option<UserMessageId>,
 55        params: acp::PromptRequest,
 56        cx: &mut App,
 57    ) -> Task<Result<acp::PromptResponse>>;
 58
 59    fn resume(
 60        &self,
 61        _session_id: &acp::SessionId,
 62        _cx: &mut App,
 63    ) -> Option<Rc<dyn AgentSessionResume>> {
 64        None
 65    }
 66
 67    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
 68
 69    fn session_editor(
 70        &self,
 71        _session_id: &acp::SessionId,
 72        _cx: &mut App,
 73    ) -> Option<Rc<dyn AgentSessionEditor>> {
 74        None
 75    }
 76
 77    /// Returns this agent as an [Rc<dyn ModelSelector>] if the model selection capability is supported.
 78    ///
 79    /// If the agent does not support model selection, returns [None].
 80    /// This allows sharing the selector in UI components.
 81    fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
 82        None
 83    }
 84
 85    fn into_any(self: Rc<Self>) -> Rc<dyn Any>;
 86}
 87
 88impl dyn AgentConnection {
 89    pub fn downcast<T: 'static + AgentConnection + Sized>(self: Rc<Self>) -> Option<Rc<T>> {
 90        self.into_any().downcast().ok()
 91    }
 92}
 93
 94pub trait AgentSessionEditor {
 95    fn truncate(&self, message_id: UserMessageId, cx: &mut App) -> Task<Result<()>>;
 96}
 97
 98pub trait AgentSessionResume {
 99    fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>>;
100}
101
/// Error signalling that authentication is required.
///
/// Both its `Display` and `Debug` forms are the text `AuthRequired`.
#[derive(Debug)]
pub struct AuthRequired;

impl fmt::Display for AuthRequired {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("AuthRequired")
    }
}

impl Error for AuthRequired {}
111
112/// Trait for agents that support listing, selecting, and querying language models.
113///
114/// This is an optional capability; agents indicate support via [AgentConnection::model_selector].
115pub trait AgentModelSelector: 'static {
116    /// Lists all available language models for this agent.
117    ///
118    /// # Parameters
119    /// - `cx`: The GPUI app context for async operations and global access.
120    ///
121    /// # Returns
122    /// A task resolving to the list of models or an error (e.g., if no models are configured).
123    fn list_models(&self, cx: &mut App) -> Task<Result<AgentModelList>>;
124
125    /// Selects a model for a specific session (thread).
126    ///
127    /// This sets the default model for future interactions in the session.
128    /// If the session doesn't exist or the model is invalid, it returns an error.
129    ///
130    /// # Parameters
131    /// - `session_id`: The ID of the session (thread) to apply the model to.
132    /// - `model`: The model to select (should be one from [list_models]).
133    /// - `cx`: The GPUI app context.
134    ///
135    /// # Returns
136    /// A task resolving to `Ok(())` on success or an error.
137    fn select_model(
138        &self,
139        session_id: acp::SessionId,
140        model_id: AgentModelId,
141        cx: &mut App,
142    ) -> Task<Result<()>>;
143
144    /// Retrieves the currently selected model for a specific session (thread).
145    ///
146    /// # Parameters
147    /// - `session_id`: The ID of the session (thread) to query.
148    /// - `cx`: The GPUI app context.
149    ///
150    /// # Returns
151    /// A task resolving to the selected model (always set) or an error (e.g., session not found).
152    fn selected_model(
153        &self,
154        session_id: &acp::SessionId,
155        cx: &mut App,
156    ) -> Task<Result<AgentModelInfo>>;
157
158    /// Whenever the model list is updated the receiver will be notified.
159    fn watch(&self, cx: &mut App) -> watch::Receiver<()>;
160}
161
162#[derive(Debug, Clone, PartialEq, Eq, Hash)]
163pub struct AgentModelId(pub SharedString);
164
165impl std::ops::Deref for AgentModelId {
166    type Target = SharedString;
167
168    fn deref(&self) -> &Self::Target {
169        &self.0
170    }
171}
172
173impl fmt::Display for AgentModelId {
174    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
175        self.0.fmt(f)
176    }
177}
178
179#[derive(Debug, Clone, PartialEq, Eq)]
180pub struct AgentModelInfo {
181    pub id: AgentModelId,
182    pub name: SharedString,
183    pub icon: Option<IconName>,
184}
185
186#[derive(Debug, Clone, PartialEq, Eq, Hash)]
187pub struct AgentModelGroupName(pub SharedString);
188
189#[derive(Debug, Clone)]
190pub enum AgentModelList {
191    Flat(Vec<AgentModelInfo>),
192    Grouped(IndexMap<AgentModelGroupName, Vec<AgentModelInfo>>),
193}
194
195impl AgentModelList {
196    pub fn is_empty(&self) -> bool {
197        match self {
198            AgentModelList::Flat(models) => models.is_empty(),
199            AgentModelList::Grouped(groups) => groups.is_empty(),
200        }
201    }
202}
203
204#[cfg(feature = "test-support")]
205mod test_support {
206    use std::sync::Arc;
207
208    use collections::HashMap;
209    use futures::{channel::oneshot, future::try_join_all};
210    use gpui::{AppContext as _, WeakEntity};
211    use parking_lot::Mutex;
212
213    use super::*;
214
215    #[derive(Clone, Default)]
216    pub struct StubAgentConnection {
217        sessions: Arc<Mutex<HashMap<acp::SessionId, Session>>>,
218        permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
219        next_prompt_updates: Arc<Mutex<Vec<acp::SessionUpdate>>>,
220    }
221
222    struct Session {
223        thread: WeakEntity<AcpThread>,
224        response_tx: Option<oneshot::Sender<acp::StopReason>>,
225    }
226
227    impl StubAgentConnection {
228        pub fn new() -> Self {
229            Self {
230                next_prompt_updates: Default::default(),
231                permission_requests: HashMap::default(),
232                sessions: Arc::default(),
233            }
234        }
235
236        pub fn set_next_prompt_updates(&self, updates: Vec<acp::SessionUpdate>) {
237            *self.next_prompt_updates.lock() = updates;
238        }
239
240        pub fn with_permission_requests(
241            mut self,
242            permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
243        ) -> Self {
244            self.permission_requests = permission_requests;
245            self
246        }
247
248        pub fn send_update(
249            &self,
250            session_id: acp::SessionId,
251            update: acp::SessionUpdate,
252            cx: &mut App,
253        ) {
254            assert!(
255                self.next_prompt_updates.lock().is_empty(),
256                "Use either send_update or set_next_prompt_updates"
257            );
258
259            self.sessions
260                .lock()
261                .get(&session_id)
262                .unwrap()
263                .thread
264                .update(cx, |thread, cx| {
265                    thread.handle_session_update(update, cx).unwrap();
266                })
267                .unwrap();
268        }
269
270        pub fn end_turn(&self, session_id: acp::SessionId, stop_reason: acp::StopReason) {
271            self.sessions
272                .lock()
273                .get_mut(&session_id)
274                .unwrap()
275                .response_tx
276                .take()
277                .expect("No pending turn")
278                .send(stop_reason)
279                .unwrap();
280        }
281    }
282
283    impl AgentConnection for StubAgentConnection {
284        fn auth_methods(&self) -> &[acp::AuthMethod] {
285            &[]
286        }
287
288        fn new_thread(
289            self: Rc<Self>,
290            project: Entity<Project>,
291            _cwd: &Path,
292            cx: &mut gpui::App,
293        ) -> Task<gpui::Result<Entity<AcpThread>>> {
294            let session_id = acp::SessionId(self.sessions.lock().len().to_string().into());
295            let thread =
296                cx.new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx));
297            self.sessions.lock().insert(
298                session_id,
299                Session {
300                    thread: thread.downgrade(),
301                    response_tx: None,
302                },
303            );
304            Task::ready(Ok(thread))
305        }
306
307        fn authenticate(
308            &self,
309            _method_id: acp::AuthMethodId,
310            _cx: &mut App,
311        ) -> Task<gpui::Result<()>> {
312            unimplemented!()
313        }
314
315        fn prompt(
316            &self,
317            _id: Option<UserMessageId>,
318            params: acp::PromptRequest,
319            cx: &mut App,
320        ) -> Task<gpui::Result<acp::PromptResponse>> {
321            let mut sessions = self.sessions.lock();
322            let Session {
323                thread,
324                response_tx,
325            } = sessions.get_mut(&params.session_id).unwrap();
326            let mut tasks = vec![];
327            if self.next_prompt_updates.lock().is_empty() {
328                let (tx, rx) = oneshot::channel();
329                response_tx.replace(tx);
330                cx.spawn(async move |_| {
331                    let stop_reason = rx.await?;
332                    Ok(acp::PromptResponse { stop_reason })
333                })
334            } else {
335                for update in self.next_prompt_updates.lock().drain(..) {
336                    let thread = thread.clone();
337                    let update = update.clone();
338                    let permission_request = if let acp::SessionUpdate::ToolCall(tool_call) =
339                        &update
340                        && let Some(options) = self.permission_requests.get(&tool_call.id)
341                    {
342                        Some((tool_call.clone(), options.clone()))
343                    } else {
344                        None
345                    };
346                    let task = cx.spawn(async move |cx| {
347                        if let Some((tool_call, options)) = permission_request {
348                            let permission = thread.update(cx, |thread, cx| {
349                                thread.request_tool_call_authorization(
350                                    tool_call.clone().into(),
351                                    options.clone(),
352                                    cx,
353                                )
354                            })?;
355                            permission?.await?;
356                        }
357                        thread.update(cx, |thread, cx| {
358                            thread.handle_session_update(update.clone(), cx).unwrap();
359                        })?;
360                        anyhow::Ok(())
361                    });
362                    tasks.push(task);
363                }
364
365                cx.spawn(async move |_| {
366                    try_join_all(tasks).await?;
367                    Ok(acp::PromptResponse {
368                        stop_reason: acp::StopReason::EndTurn,
369                    })
370                })
371            }
372        }
373
374        fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
375            if let Some(end_turn_tx) = self
376                .sessions
377                .lock()
378                .get_mut(session_id)
379                .unwrap()
380                .response_tx
381                .take()
382            {
383                end_turn_tx.send(acp::StopReason::Canceled).unwrap();
384            }
385        }
386
387        fn session_editor(
388            &self,
389            _session_id: &agent_client_protocol::SessionId,
390            _cx: &mut App,
391        ) -> Option<Rc<dyn AgentSessionEditor>> {
392            Some(Rc::new(StubAgentSessionEditor))
393        }
394
395        fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
396            self
397        }
398    }
399
400    struct StubAgentSessionEditor;
401
402    impl AgentSessionEditor for StubAgentSessionEditor {
403        fn truncate(&self, _: UserMessageId, _: &mut App) -> Task<Result<()>> {
404            Task::ready(Ok(()))
405        }
406    }
407}
408
409#[cfg(feature = "test-support")]
410pub use test_support::*;