connection.rs

  1use crate::{AcpThread, AcpThreadMetadata};
  2use agent_client_protocol::{self as acp};
  3use anyhow::Result;
  4use collections::IndexMap;
  5use gpui::{Entity, SharedString, Task};
  6use project::Project;
  7use serde::{Deserialize, Serialize};
  8use std::{any::Any, error::Error, fmt, path::Path, rc::Rc, sync::Arc};
  9use ui::{App, IconName};
 10use uuid::Uuid;
 11
 12#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
 13pub struct UserMessageId(Arc<str>);
 14
 15impl UserMessageId {
 16    pub fn new() -> Self {
 17        Self(Uuid::new_v4().to_string().into())
 18    }
 19}
 20
 21pub trait AgentConnection {
 22    fn new_thread(
 23        self: Rc<Self>,
 24        project: Entity<Project>,
 25        cwd: &Path,
 26        cx: &mut App,
 27    ) -> Task<Result<Entity<AcpThread>>>;
 28
 29    // todo!(expose a history trait, and include list_threads and load_thread)
 30    // todo!(write a test)
 31    fn list_threads(
 32        &self,
 33        _cx: &mut App,
 34    ) -> Option<watch::Receiver<Option<Vec<AcpThreadMetadata>>>> {
 35        return None;
 36    }
 37
 38    fn load_thread(
 39        self: Rc<Self>,
 40        _project: Entity<Project>,
 41        _cwd: &Path,
 42        _session_id: acp::SessionId,
 43        _cx: &mut App,
 44    ) -> Task<Result<Entity<AcpThread>>> {
 45        Task::ready(Err(anyhow::anyhow!("load thread not implemented")))
 46    }
 47
 48    fn auth_methods(&self) -> &[acp::AuthMethod];
 49
 50    fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;
 51
 52    fn prompt(
 53        &self,
 54        user_message_id: Option<UserMessageId>,
 55        params: acp::PromptRequest,
 56        cx: &mut App,
 57    ) -> Task<Result<acp::PromptResponse>>;
 58
 59    fn resume(
 60        &self,
 61        _session_id: &acp::SessionId,
 62        _cx: &mut App,
 63    ) -> Option<Rc<dyn AgentSessionResume>> {
 64        None
 65    }
 66
 67    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
 68
 69    fn session_editor(
 70        &self,
 71        _session_id: &acp::SessionId,
 72        _cx: &mut App,
 73    ) -> Option<Rc<dyn AgentSessionEditor>> {
 74        None
 75    }
 76
 77    /// Returns this agent as an [Rc<dyn ModelSelector>] if the model selection capability is supported.
 78    ///
 79    /// If the agent does not support model selection, returns [None].
 80    /// This allows sharing the selector in UI components.
 81    fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
 82        None
 83    }
 84
 85    fn into_any(self: Rc<Self>) -> Rc<dyn Any>;
 86}
 87
 88impl dyn AgentConnection {
 89    pub fn downcast<T: 'static + AgentConnection + Sized>(self: Rc<Self>) -> Option<Rc<T>> {
 90        self.into_any().downcast().ok()
 91    }
 92}
 93
 94pub trait AgentSessionEditor {
 95    fn truncate(&self, message_id: UserMessageId, cx: &mut App) -> Task<Result<()>>;
 96}
 97
 98pub trait AgentSessionResume {
 99    fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>>;
100}
101
/// Error value signalling that authentication is required.
///
/// Both its `Display` and `Debug` output are the literal text `AuthRequired`.
#[derive(Debug)]
pub struct AuthRequired;

impl fmt::Display for AuthRequired {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("AuthRequired")
    }
}

impl Error for AuthRequired {}
111
112/// Trait for agents that support listing, selecting, and querying language models.
113///
114/// This is an optional capability; agents indicate support via [AgentConnection::model_selector].
115pub trait AgentModelSelector: 'static {
116    /// Lists all available language models for this agent.
117    ///
118    /// # Parameters
119    /// - `cx`: The GPUI app context for async operations and global access.
120    ///
121    /// # Returns
122    /// A task resolving to the list of models or an error (e.g., if no models are configured).
123    fn list_models(&self, cx: &mut App) -> Task<Result<AgentModelList>>;
124
125    /// Selects a model for a specific session (thread).
126    ///
127    /// This sets the default model for future interactions in the session.
128    /// If the session doesn't exist or the model is invalid, it returns an error.
129    ///
130    /// # Parameters
131    /// - `session_id`: The ID of the session (thread) to apply the model to.
132    /// - `model`: The model to select (should be one from [list_models]).
133    /// - `cx`: The GPUI app context.
134    ///
135    /// # Returns
136    /// A task resolving to `Ok(())` on success or an error.
137    fn select_model(
138        &self,
139        session_id: acp::SessionId,
140        model_id: AgentModelId,
141        cx: &mut App,
142    ) -> Task<Result<()>>;
143
144    /// Retrieves the currently selected model for a specific session (thread).
145    ///
146    /// # Parameters
147    /// - `session_id`: The ID of the session (thread) to query.
148    /// - `cx`: The GPUI app context.
149    ///
150    /// # Returns
151    /// A task resolving to the selected model (always set) or an error (e.g., session not found).
152    fn selected_model(
153        &self,
154        session_id: &acp::SessionId,
155        cx: &mut App,
156    ) -> Task<Result<AgentModelInfo>>;
157
158    /// Whenever the model list is updated the receiver will be notified.
159    fn watch(&self, cx: &mut App) -> watch::Receiver<()>;
160}
161
162#[derive(Debug, Clone, PartialEq, Eq, Hash)]
163pub struct AgentModelId(pub SharedString);
164
165impl std::ops::Deref for AgentModelId {
166    type Target = SharedString;
167
168    fn deref(&self) -> &Self::Target {
169        &self.0
170    }
171}
172
173impl fmt::Display for AgentModelId {
174    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
175        self.0.fmt(f)
176    }
177}
178
179#[derive(Debug, Clone, PartialEq, Eq)]
180pub struct AgentModelInfo {
181    pub id: AgentModelId,
182    pub name: SharedString,
183    pub icon: Option<IconName>,
184}
185
186#[derive(Debug, Clone, PartialEq, Eq, Hash)]
187pub struct AgentModelGroupName(pub SharedString);
188
189#[derive(Debug, Clone)]
190pub enum AgentModelList {
191    Flat(Vec<AgentModelInfo>),
192    Grouped(IndexMap<AgentModelGroupName, Vec<AgentModelInfo>>),
193}
194
195impl AgentModelList {
196    pub fn is_empty(&self) -> bool {
197        match self {
198            AgentModelList::Flat(models) => models.is_empty(),
199            AgentModelList::Grouped(groups) => groups.is_empty(),
200        }
201    }
202}
203
#[cfg(feature = "test-support")]
mod test_support {
    use std::sync::Arc;

    use collections::HashMap;
    use futures::future::try_join_all;
    use gpui::{AppContext as _, WeakEntity};
    use parking_lot::Mutex;

    use super::*;

    /// In-memory [`AgentConnection`] for tests.
    ///
    /// Each `prompt` call replays the updates queued with
    /// [`StubAgentConnection::set_next_prompt_updates`] into the prompted thread.
    /// A `ToolCall` update whose ID has an entry in the permission map first
    /// requests authorization and waits for it before being applied.
    #[derive(Clone, Default)]
    pub struct StubAgentConnection {
        /// Threads created by this connection, keyed by session ID.
        sessions: Arc<Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
        /// Permission options to request before applying a matching `ToolCall` update.
        permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
        /// Updates replayed by the next `prompt` call, which drains them.
        next_prompt_updates: Arc<Mutex<Vec<acp::SessionUpdate>>>,
    }

    impl StubAgentConnection {
        /// Creates a connection with no sessions, no queued updates, and no
        /// permission requests.
        pub fn new() -> Self {
            Self::default()
        }

        /// Replaces the queue of updates that the next `prompt` call will replay.
        pub fn set_next_prompt_updates(&self, updates: Vec<acp::SessionUpdate>) {
            *self.next_prompt_updates.lock() = updates;
        }

        /// Builder-style setter for the permission options requested per tool call.
        pub fn with_permission_requests(
            mut self,
            permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
        ) -> Self {
            self.permission_requests = permission_requests;
            self
        }

        /// Applies `update` to the thread for `session_id` immediately.
        ///
        /// # Panics
        /// Panics if the session is unknown, its thread has been dropped, or the
        /// thread rejects the update.
        pub fn send_update(
            &self,
            session_id: acp::SessionId,
            update: acp::SessionUpdate,
            cx: &mut App,
        ) {
            // Copy the weak handle out so the (non-reentrant) sessions lock is
            // released before the thread's update handler runs.
            let thread = self
                .sessions
                .lock()
                .get(&session_id)
                .cloned()
                .expect("send_update called with an unknown session id");
            thread
                .update(cx, |thread, cx| {
                    thread
                        .handle_session_update(update, cx)
                        .expect("thread rejected session update");
                })
                .expect("thread for session was dropped");
        }
    }

    impl AgentConnection for StubAgentConnection {
        fn auth_methods(&self) -> &[acp::AuthMethod] {
            &[]
        }

        /// Creates a thread whose session ID is the number of sessions created so far.
        fn new_thread(
            self: Rc<Self>,
            project: Entity<Project>,
            _cwd: &Path,
            cx: &mut gpui::App,
        ) -> Task<gpui::Result<Entity<AcpThread>>> {
            let session_id = acp::SessionId(self.sessions.lock().len().to_string().into());
            let thread =
                cx.new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx));
            self.sessions.lock().insert(session_id, thread.downgrade());
            Task::ready(Ok(thread))
        }

        /// # Panics
        /// Always; authentication is not supported by the stub.
        fn authenticate(
            &self,
            _method_id: acp::AuthMethodId,
            _cx: &mut App,
        ) -> Task<gpui::Result<()>> {
            unimplemented!()
        }

        /// Replays the queued updates into the prompted thread, then ends the turn.
        ///
        /// # Panics
        /// Panics if `params.session_id` was not created by this connection.
        fn prompt(
            &self,
            _id: Option<UserMessageId>,
            params: acp::PromptRequest,
            cx: &mut App,
        ) -> Task<gpui::Result<acp::PromptResponse>> {
            let thread = self
                .sessions
                .lock()
                .get(&params.session_id)
                .cloned()
                .expect("prompt called with an unknown session id");
            // Take the queue up front so no lock is held while spawning tasks.
            let updates = std::mem::take(&mut *self.next_prompt_updates.lock());

            let mut tasks = Vec::with_capacity(updates.len());
            for update in updates {
                let thread = thread.clone();
                let permission_request = match &update {
                    acp::SessionUpdate::ToolCall(tool_call) => self
                        .permission_requests
                        .get(&tool_call.id)
                        .map(|options| (tool_call.clone(), options.clone())),
                    _ => None,
                };
                let task = cx.spawn(async move |cx| {
                    if let Some((tool_call, options)) = permission_request {
                        let permission = thread.update(cx, |thread, cx| {
                            thread.request_tool_call_authorization(tool_call.into(), options, cx)
                        })?;
                        permission?.await?;
                    }
                    thread.update(cx, |thread, cx| {
                        thread.handle_session_update(update, cx).unwrap();
                    })?;
                    anyhow::Ok(())
                });
                tasks.push(task);
            }

            cx.spawn(async move |_| {
                try_join_all(tasks).await?;
                Ok(acp::PromptResponse {
                    stop_reason: acp::StopReason::EndTurn,
                })
            })
        }

        /// # Panics
        /// Always; cancellation is not supported by the stub.
        fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) {
            unimplemented!()
        }

        fn session_editor(
            &self,
            _session_id: &agent_client_protocol::SessionId,
            _cx: &mut App,
        ) -> Option<Rc<dyn AgentSessionEditor>> {
            Some(Rc::new(StubAgentSessionEditor))
        }

        fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
            self
        }
    }

    /// Session editor whose `truncate` does nothing and always succeeds.
    struct StubAgentSessionEditor;

    impl AgentSessionEditor for StubAgentSessionEditor {
        fn truncate(&self, _: UserMessageId, _: &mut App) -> Task<Result<()>> {
            Task::ready(Ok(()))
        }
    }
}
356
357#[cfg(feature = "test-support")]
358pub use test_support::*;