connection.rs

  1use crate::{AcpThread, AcpThreadMetadata};
  2use agent_client_protocol::{self as acp};
  3use anyhow::Result;
  4use collections::IndexMap;
  5use futures::channel::mpsc::UnboundedReceiver;
  6use gpui::{Entity, SharedString, Task};
  7use project::Project;
  8use serde::{Deserialize, Serialize};
  9use std::{any::Any, error::Error, fmt, path::Path, rc::Rc, sync::Arc};
 10use ui::{App, IconName};
 11use uuid::Uuid;
 12
 13#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
 14pub struct UserMessageId(Arc<str>);
 15
 16impl UserMessageId {
 17    pub fn new() -> Self {
 18        Self(Uuid::new_v4().to_string().into())
 19    }
 20}
 21
/// Connection to an agent: creates threads, handles authentication and prompts, and
/// exposes optional capabilities through accessor methods that return `None` by default.
pub trait AgentConnection {
    /// Starts a new thread for `project`, resolving to the created [`AcpThread`] entity.
    ///
    /// NOTE(review): `cwd` is presumably the working directory for the new session —
    /// confirm against implementors.
    fn new_thread(
        self: Rc<Self>,
        project: Entity<Project>,
        cwd: &Path,
        cx: &mut App,
    ) -> Task<Result<Entity<AcpThread>>>;

    /// Returns the authentication methods offered by this connection.
    fn auth_methods(&self) -> &[acp::AuthMethod];

    /// Authenticates using the method identified by `method`.
    fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;

    /// Sends a prompt request and resolves to the agent's response.
    ///
    /// `user_message_id` optionally identifies the user message that carries the prompt.
    fn prompt(
        &self,
        user_message_id: Option<UserMessageId>,
        params: acp::PromptRequest,
        cx: &mut App,
    ) -> Task<Result<acp::PromptResponse>>;

    /// Returns a handle for resuming the given session.
    ///
    /// The default implementation returns [None], meaning resuming is not supported.
    fn resume(
        &self,
        _session_id: &acp::SessionId,
        _cx: &mut App,
    ) -> Option<Rc<dyn AgentSessionResume>> {
        None
    }

    /// Cancels the given session's in-progress work.
    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);

    /// Returns an editor for the given session.
    ///
    /// The default implementation returns [None], meaning session editing is not supported.
    fn session_editor(
        &self,
        _session_id: &acp::SessionId,
        _cx: &mut App,
    ) -> Option<Rc<dyn AgentSessionEditor>> {
        None
    }

    /// Returns this agent as an `Rc<dyn AgentModelSelector>` if the model selection
    /// capability is supported.
    ///
    /// If the agent does not support model selection, returns [None].
    /// This allows sharing the selector in UI components.
    fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
        None
    }

    /// Returns this agent's thread history interface.
    ///
    /// The default implementation returns [None], meaning history is not supported.
    fn history(self: Rc<Self>) -> Option<Rc<dyn AgentHistory>> {
        None
    }

    /// Converts the connection into `Rc<dyn Any>` so that `downcast` on
    /// `dyn AgentConnection` can recover the concrete type.
    fn into_any(self: Rc<Self>) -> Rc<dyn Any>;
}
 73
 74impl dyn AgentConnection {
 75    pub fn downcast<T: 'static + AgentConnection + Sized>(self: Rc<Self>) -> Option<Rc<T>> {
 76        self.into_any().downcast().ok()
 77    }
 78}
 79
/// Editing capability for a session, obtained from [AgentConnection::session_editor].
pub trait AgentSessionEditor {
    /// Truncates the session relative to the user message identified by `message_id`.
    ///
    /// NOTE(review): whether the identified message itself is kept or removed is decided by
    /// the implementor — confirm before relying on either.
    fn truncate(&self, message_id: UserMessageId, cx: &mut App) -> Task<Result<()>>;
}
 83
/// Resume capability for a session, obtained from [AgentConnection::resume].
pub trait AgentSessionResume {
    /// Runs the resume operation, resolving to a prompt response.
    fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>>;
}
 87
/// Thread history capability, obtained from [AgentConnection::history].
pub trait AgentHistory {
    /// Resolves to metadata for the agent's threads.
    fn list_threads(&self, cx: &mut App) -> Task<Result<Vec<AcpThreadMetadata>>>;
    /// Returns a receiver that yields thread metadata.
    ///
    /// NOTE(review): presumably one item is sent whenever a thread is added or changed —
    /// confirm against implementors.
    fn observe_history(&self, cx: &mut App) -> UnboundedReceiver<AcpThreadMetadata>;
    /// Loads the thread for `_session_id`, resolving to its [`AcpThread`] entity.
    fn load_thread(
        self: Rc<Self>,
        _project: Entity<Project>,
        _cwd: &Path,
        _session_id: acp::SessionId,
        _cx: &mut App,
    ) -> Task<Result<Entity<AcpThread>>>;
}
 99
/// Error value signalling that authentication is required.
#[derive(Debug)]
pub struct AuthRequired;

impl fmt::Display for AuthRequired {
    /// Writes the fixed text `AuthRequired`; formatter width and padding are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("AuthRequired")
    }
}

impl Error for AuthRequired {}
109
/// Trait for agents that support listing, selecting, and querying language models.
///
/// This is an optional capability; agents indicate support via [AgentConnection::model_selector].
pub trait AgentModelSelector: 'static {
    /// Lists all available language models for this agent.
    ///
    /// # Parameters
    /// - `cx`: The GPUI app context for async operations and global access.
    ///
    /// # Returns
    /// A task resolving to the list of models or an error (e.g., if no models are configured).
    fn list_models(&self, cx: &mut App) -> Task<Result<AgentModelList>>;

    /// Selects a model for a specific session (thread).
    ///
    /// This sets the default model for future interactions in the session.
    /// If the session doesn't exist or the model is invalid, it returns an error.
    ///
    /// # Parameters
    /// - `session_id`: The ID of the session (thread) to apply the model to.
    /// - `model_id`: The ID of the model to select (should identify one of the models
    ///   returned by [Self::list_models]).
    /// - `cx`: The GPUI app context.
    ///
    /// # Returns
    /// A task resolving to `Ok(())` on success or an error.
    fn select_model(
        &self,
        session_id: acp::SessionId,
        model_id: AgentModelId,
        cx: &mut App,
    ) -> Task<Result<()>>;

    /// Retrieves the currently selected model for a specific session (thread).
    ///
    /// # Parameters
    /// - `session_id`: The ID of the session (thread) to query.
    /// - `cx`: The GPUI app context.
    ///
    /// # Returns
    /// A task resolving to the selected model (always set) or an error (e.g., session not found).
    fn selected_model(
        &self,
        session_id: &acp::SessionId,
        cx: &mut App,
    ) -> Task<Result<AgentModelInfo>>;

    /// Whenever the model list is updated the receiver will be notified.
    ///
    /// The receiver carries no payload (`()`); call [Self::list_models] to fetch the
    /// current list after a notification.
    fn watch(&self, cx: &mut App) -> watch::Receiver<()>;
}
159
160#[derive(Debug, Clone, PartialEq, Eq, Hash)]
161pub struct AgentModelId(pub SharedString);
162
163impl std::ops::Deref for AgentModelId {
164    type Target = SharedString;
165
166    fn deref(&self) -> &Self::Target {
167        &self.0
168    }
169}
170
171impl fmt::Display for AgentModelId {
172    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
173        self.0.fmt(f)
174    }
175}
176
/// Describes a single language model.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AgentModelInfo {
    /// Identifier of the model; this is the type [AgentModelSelector::select_model] accepts.
    pub id: AgentModelId,
    /// Name of the model (presumably the human-readable label shown in UI — confirm).
    pub name: SharedString,
    /// Optional icon associated with the model.
    pub icon: Option<IconName>,
}
183
/// Name of a model group; used as the key of [AgentModelList::Grouped].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct AgentModelGroupName(pub SharedString);
186
187#[derive(Debug, Clone)]
188pub enum AgentModelList {
189    Flat(Vec<AgentModelInfo>),
190    Grouped(IndexMap<AgentModelGroupName, Vec<AgentModelInfo>>),
191}
192
193impl AgentModelList {
194    pub fn is_empty(&self) -> bool {
195        match self {
196            AgentModelList::Flat(models) => models.is_empty(),
197            AgentModelList::Grouped(groups) => groups.is_empty(),
198        }
199    }
200}
201
202#[cfg(feature = "test-support")]
203mod test_support {
204    use std::sync::Arc;
205
206    use collections::HashMap;
207    use futures::{channel::oneshot, future::try_join_all};
208    use gpui::{AppContext as _, WeakEntity};
209    use parking_lot::Mutex;
210
211    use super::*;
212
    /// Stub [AgentConnection] implementation, compiled only with the `test-support` feature.
    ///
    /// Updates are either queued up front with `set_next_prompt_updates` or pushed one at a
    /// time with `send_update` (the latter asserts that the queue is empty).
    #[derive(Clone, Default)]
    pub struct StubAgentConnection {
        /// Sessions created by `new_thread`, keyed by session id; the `Arc` means clones of
        /// the connection share the same map.
        sessions: Arc<Mutex<HashMap<acp::SessionId, Session>>>,
        /// Permission options to request before applying a queued tool-call update whose
        /// id matches the key.
        permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
        /// Updates replayed by the next `prompt` call; when empty, `prompt` stays pending
        /// until `end_turn` or `cancel` supplies a stop reason.
        next_prompt_updates: Arc<Mutex<Vec<acp::SessionUpdate>>>,
    }
219
    /// Per-session state tracked by the stub connection.
    struct Session {
        /// Weak handle to the thread created for this session.
        thread: WeakEntity<AcpThread>,
        /// Sender that completes the pending `prompt` with a stop reason; `None` when no
        /// turn is waiting to be ended.
        response_tx: Option<oneshot::Sender<acp::StopReason>>,
    }
224
225    impl StubAgentConnection {
226        pub fn new() -> Self {
227            Self {
228                next_prompt_updates: Default::default(),
229                permission_requests: HashMap::default(),
230                sessions: Arc::default(),
231            }
232        }
233
234        pub fn set_next_prompt_updates(&self, updates: Vec<acp::SessionUpdate>) {
235            *self.next_prompt_updates.lock() = updates;
236        }
237
238        pub fn with_permission_requests(
239            mut self,
240            permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
241        ) -> Self {
242            self.permission_requests = permission_requests;
243            self
244        }
245
246        pub fn send_update(
247            &self,
248            session_id: acp::SessionId,
249            update: acp::SessionUpdate,
250            cx: &mut App,
251        ) {
252            assert!(
253                self.next_prompt_updates.lock().is_empty(),
254                "Use either send_update or set_next_prompt_updates"
255            );
256
257            self.sessions
258                .lock()
259                .get(&session_id)
260                .unwrap()
261                .thread
262                .update(cx, |thread, cx| {
263                    thread.handle_session_update(update, cx).unwrap();
264                })
265                .unwrap();
266        }
267
268        pub fn end_turn(&self, session_id: acp::SessionId, stop_reason: acp::StopReason) {
269            self.sessions
270                .lock()
271                .get_mut(&session_id)
272                .unwrap()
273                .response_tx
274                .take()
275                .expect("No pending turn")
276                .send(stop_reason)
277                .unwrap();
278        }
279    }
280
281    impl AgentConnection for StubAgentConnection {
282        fn auth_methods(&self) -> &[acp::AuthMethod] {
283            &[]
284        }
285
286        fn new_thread(
287            self: Rc<Self>,
288            project: Entity<Project>,
289            _cwd: &Path,
290            cx: &mut gpui::App,
291        ) -> Task<gpui::Result<Entity<AcpThread>>> {
292            let session_id = acp::SessionId(self.sessions.lock().len().to_string().into());
293            let thread =
294                cx.new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx));
295            self.sessions.lock().insert(
296                session_id,
297                Session {
298                    thread: thread.downgrade(),
299                    response_tx: None,
300                },
301            );
302            Task::ready(Ok(thread))
303        }
304
305        fn authenticate(
306            &self,
307            _method_id: acp::AuthMethodId,
308            _cx: &mut App,
309        ) -> Task<gpui::Result<()>> {
310            unimplemented!()
311        }
312
313        fn prompt(
314            &self,
315            _id: Option<UserMessageId>,
316            params: acp::PromptRequest,
317            cx: &mut App,
318        ) -> Task<gpui::Result<acp::PromptResponse>> {
319            let mut sessions = self.sessions.lock();
320            let Session {
321                thread,
322                response_tx,
323            } = sessions.get_mut(&params.session_id).unwrap();
324            let mut tasks = vec![];
325            if self.next_prompt_updates.lock().is_empty() {
326                let (tx, rx) = oneshot::channel();
327                response_tx.replace(tx);
328                cx.spawn(async move |_| {
329                    let stop_reason = rx.await?;
330                    Ok(acp::PromptResponse { stop_reason })
331                })
332            } else {
333                for update in self.next_prompt_updates.lock().drain(..) {
334                    let thread = thread.clone();
335                    let update = update.clone();
336                    let permission_request = if let acp::SessionUpdate::ToolCall(tool_call) =
337                        &update
338                        && let Some(options) = self.permission_requests.get(&tool_call.id)
339                    {
340                        Some((tool_call.clone(), options.clone()))
341                    } else {
342                        None
343                    };
344                    let task = cx.spawn(async move |cx| {
345                        if let Some((tool_call, options)) = permission_request {
346                            let permission = thread.update(cx, |thread, cx| {
347                                thread.request_tool_call_authorization(
348                                    tool_call.clone().into(),
349                                    options.clone(),
350                                    cx,
351                                )
352                            })?;
353                            permission?.await?;
354                        }
355                        thread.update(cx, |thread, cx| {
356                            thread.handle_session_update(update.clone(), cx).unwrap();
357                        })?;
358                        anyhow::Ok(())
359                    });
360                    tasks.push(task);
361                }
362
363                cx.spawn(async move |_| {
364                    try_join_all(tasks).await?;
365                    Ok(acp::PromptResponse {
366                        stop_reason: acp::StopReason::EndTurn,
367                    })
368                })
369            }
370        }
371
372        fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
373            if let Some(end_turn_tx) = self
374                .sessions
375                .lock()
376                .get_mut(session_id)
377                .unwrap()
378                .response_tx
379                .take()
380            {
381                end_turn_tx.send(acp::StopReason::Canceled).unwrap();
382            }
383        }
384
385        fn session_editor(
386            &self,
387            _session_id: &agent_client_protocol::SessionId,
388            _cx: &mut App,
389        ) -> Option<Rc<dyn AgentSessionEditor>> {
390            Some(Rc::new(StubAgentSessionEditor))
391        }
392
393        fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
394            self
395        }
396    }
397
398    struct StubAgentSessionEditor;
399
400    impl AgentSessionEditor for StubAgentSessionEditor {
401        fn truncate(&self, _: UserMessageId, _: &mut App) -> Task<Result<()>> {
402            Task::ready(Ok(()))
403        }
404    }
405}
406
407#[cfg(feature = "test-support")]
408pub use test_support::*;