connection.rs

  1use crate::AcpThread;
  2use agent_client_protocol::{self as acp};
  3use anyhow::Result;
  4use collections::IndexMap;
  5use gpui::{Entity, SharedString, Task};
  6use project::Project;
  7use std::{any::Any, error::Error, fmt, path::Path, rc::Rc, sync::Arc};
  8use ui::{App, IconName};
  9use uuid::Uuid;
 10
 11#[derive(Clone, Debug, Eq, PartialEq)]
 12pub struct UserMessageId(Arc<str>);
 13
 14impl UserMessageId {
 15    pub fn new() -> Self {
 16        Self(Uuid::new_v4().to_string().into())
 17    }
 18}
 19
 20pub trait AgentConnection {
 21    fn new_thread(
 22        self: Rc<Self>,
 23        project: Entity<Project>,
 24        cwd: &Path,
 25        cx: &mut App,
 26    ) -> Task<Result<Entity<AcpThread>>>;
 27
 28    fn auth_methods(&self) -> &[acp::AuthMethod];
 29
 30    fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;
 31
 32    fn prompt(
 33        &self,
 34        user_message_id: Option<UserMessageId>,
 35        params: acp::PromptRequest,
 36        cx: &mut App,
 37    ) -> Task<Result<acp::PromptResponse>>;
 38
 39    fn resume(
 40        &self,
 41        _session_id: &acp::SessionId,
 42        _cx: &mut App,
 43    ) -> Option<Rc<dyn AgentSessionResume>> {
 44        None
 45    }
 46
 47    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
 48
 49    fn session_editor(
 50        &self,
 51        _session_id: &acp::SessionId,
 52        _cx: &mut App,
 53    ) -> Option<Rc<dyn AgentSessionEditor>> {
 54        None
 55    }
 56
 57    /// Returns this agent as an [Rc<dyn ModelSelector>] if the model selection capability is supported.
 58    ///
 59    /// If the agent does not support model selection, returns [None].
 60    /// This allows sharing the selector in UI components.
 61    fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
 62        None
 63    }
 64
 65    fn into_any(self: Rc<Self>) -> Rc<dyn Any>;
 66}
 67
 68impl dyn AgentConnection {
 69    pub fn downcast<T: 'static + AgentConnection + Sized>(self: Rc<Self>) -> Option<Rc<T>> {
 70        self.into_any().downcast().ok()
 71    }
 72}
 73
 74pub trait AgentSessionEditor {
 75    fn truncate(&self, message_id: UserMessageId, cx: &mut App) -> Task<Result<()>>;
 76}
 77
 78pub trait AgentSessionResume {
 79    fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>>;
 80}
 81
 82#[derive(Debug)]
 83pub struct AuthRequired;
 84
 85impl Error for AuthRequired {}
 86impl fmt::Display for AuthRequired {
 87    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 88        write!(f, "AuthRequired")
 89    }
 90}
 91
 92/// Trait for agents that support listing, selecting, and querying language models.
 93///
 94/// This is an optional capability; agents indicate support via [AgentConnection::model_selector].
 95pub trait AgentModelSelector: 'static {
 96    /// Lists all available language models for this agent.
 97    ///
 98    /// # Parameters
 99    /// - `cx`: The GPUI app context for async operations and global access.
100    ///
101    /// # Returns
102    /// A task resolving to the list of models or an error (e.g., if no models are configured).
103    fn list_models(&self, cx: &mut App) -> Task<Result<AgentModelList>>;
104
105    /// Selects a model for a specific session (thread).
106    ///
107    /// This sets the default model for future interactions in the session.
108    /// If the session doesn't exist or the model is invalid, it returns an error.
109    ///
110    /// # Parameters
111    /// - `session_id`: The ID of the session (thread) to apply the model to.
112    /// - `model`: The model to select (should be one from [list_models]).
113    /// - `cx`: The GPUI app context.
114    ///
115    /// # Returns
116    /// A task resolving to `Ok(())` on success or an error.
117    fn select_model(
118        &self,
119        session_id: acp::SessionId,
120        model_id: AgentModelId,
121        cx: &mut App,
122    ) -> Task<Result<()>>;
123
124    /// Retrieves the currently selected model for a specific session (thread).
125    ///
126    /// # Parameters
127    /// - `session_id`: The ID of the session (thread) to query.
128    /// - `cx`: The GPUI app context.
129    ///
130    /// # Returns
131    /// A task resolving to the selected model (always set) or an error (e.g., session not found).
132    fn selected_model(
133        &self,
134        session_id: &acp::SessionId,
135        cx: &mut App,
136    ) -> Task<Result<AgentModelInfo>>;
137
138    /// Whenever the model list is updated the receiver will be notified.
139    fn watch(&self, cx: &mut App) -> watch::Receiver<()>;
140}
141
142#[derive(Debug, Clone, PartialEq, Eq, Hash)]
143pub struct AgentModelId(pub SharedString);
144
145impl std::ops::Deref for AgentModelId {
146    type Target = SharedString;
147
148    fn deref(&self) -> &Self::Target {
149        &self.0
150    }
151}
152
153impl fmt::Display for AgentModelId {
154    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
155        self.0.fmt(f)
156    }
157}
158
159#[derive(Debug, Clone, PartialEq, Eq)]
160pub struct AgentModelInfo {
161    pub id: AgentModelId,
162    pub name: SharedString,
163    pub icon: Option<IconName>,
164}
165
166#[derive(Debug, Clone, PartialEq, Eq, Hash)]
167pub struct AgentModelGroupName(pub SharedString);
168
169#[derive(Debug, Clone)]
170pub enum AgentModelList {
171    Flat(Vec<AgentModelInfo>),
172    Grouped(IndexMap<AgentModelGroupName, Vec<AgentModelInfo>>),
173}
174
175impl AgentModelList {
176    pub fn is_empty(&self) -> bool {
177        match self {
178            AgentModelList::Flat(models) => models.is_empty(),
179            AgentModelList::Grouped(groups) => groups.is_empty(),
180        }
181    }
182}
183
184#[cfg(feature = "test-support")]
185mod test_support {
186    use std::sync::Arc;
187
188    use collections::HashMap;
189    use futures::future::try_join_all;
190    use gpui::{AppContext as _, WeakEntity};
191    use parking_lot::Mutex;
192
193    use super::*;
194
195    #[derive(Clone, Default)]
196    pub struct StubAgentConnection {
197        sessions: Arc<Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
198        permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
199        next_prompt_updates: Arc<Mutex<Vec<acp::SessionUpdate>>>,
200    }
201
202    impl StubAgentConnection {
203        pub fn new() -> Self {
204            Self {
205                next_prompt_updates: Default::default(),
206                permission_requests: HashMap::default(),
207                sessions: Arc::default(),
208            }
209        }
210
211        pub fn set_next_prompt_updates(&self, updates: Vec<acp::SessionUpdate>) {
212            *self.next_prompt_updates.lock() = updates;
213        }
214
215        pub fn with_permission_requests(
216            mut self,
217            permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
218        ) -> Self {
219            self.permission_requests = permission_requests;
220            self
221        }
222
223        pub fn send_update(
224            &self,
225            session_id: acp::SessionId,
226            update: acp::SessionUpdate,
227            cx: &mut App,
228        ) {
229            self.sessions
230                .lock()
231                .get(&session_id)
232                .unwrap()
233                .update(cx, |thread, cx| {
234                    thread.handle_session_update(update.clone(), cx).unwrap();
235                })
236                .unwrap();
237        }
238    }
239
240    impl AgentConnection for StubAgentConnection {
241        fn auth_methods(&self) -> &[acp::AuthMethod] {
242            &[]
243        }
244
245        fn new_thread(
246            self: Rc<Self>,
247            project: Entity<Project>,
248            _cwd: &Path,
249            cx: &mut gpui::App,
250        ) -> Task<gpui::Result<Entity<AcpThread>>> {
251            let session_id = acp::SessionId(self.sessions.lock().len().to_string().into());
252            let thread =
253                cx.new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx));
254            self.sessions.lock().insert(session_id, thread.downgrade());
255            Task::ready(Ok(thread))
256        }
257
258        fn authenticate(
259            &self,
260            _method_id: acp::AuthMethodId,
261            _cx: &mut App,
262        ) -> Task<gpui::Result<()>> {
263            unimplemented!()
264        }
265
266        fn prompt(
267            &self,
268            _id: Option<UserMessageId>,
269            params: acp::PromptRequest,
270            cx: &mut App,
271        ) -> Task<gpui::Result<acp::PromptResponse>> {
272            let sessions = self.sessions.lock();
273            let thread = sessions.get(&params.session_id).unwrap();
274            let mut tasks = vec![];
275            for update in self.next_prompt_updates.lock().drain(..) {
276                let thread = thread.clone();
277                let update = update.clone();
278                let permission_request = if let acp::SessionUpdate::ToolCall(tool_call) = &update
279                    && let Some(options) = self.permission_requests.get(&tool_call.id)
280                {
281                    Some((tool_call.clone(), options.clone()))
282                } else {
283                    None
284                };
285                let task = cx.spawn(async move |cx| {
286                    if let Some((tool_call, options)) = permission_request {
287                        let permission = thread.update(cx, |thread, cx| {
288                            thread.request_tool_call_authorization(
289                                tool_call.clone().into(),
290                                options.clone(),
291                                cx,
292                            )
293                        })?;
294                        permission?.await?;
295                    }
296                    thread.update(cx, |thread, cx| {
297                        thread.handle_session_update(update.clone(), cx).unwrap();
298                    })?;
299                    anyhow::Ok(())
300                });
301                tasks.push(task);
302            }
303            cx.spawn(async move |_| {
304                try_join_all(tasks).await?;
305                Ok(acp::PromptResponse {
306                    stop_reason: acp::StopReason::EndTurn,
307                })
308            })
309        }
310
311        fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) {
312            unimplemented!()
313        }
314
315        fn session_editor(
316            &self,
317            _session_id: &agent_client_protocol::SessionId,
318            _cx: &mut App,
319        ) -> Option<Rc<dyn AgentSessionEditor>> {
320            Some(Rc::new(StubAgentSessionEditor))
321        }
322
323        fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
324            self
325        }
326    }
327
328    struct StubAgentSessionEditor;
329
330    impl AgentSessionEditor for StubAgentSessionEditor {
331        fn truncate(&self, _: UserMessageId, _: &mut App) -> Task<Result<()>> {
332            Task::ready(Ok(()))
333        }
334    }
335}
336
337#[cfg(feature = "test-support")]
338pub use test_support::*;