connection.rs

  1use crate::{AcpThread, AcpThreadMetadata};
  2use agent_client_protocol::{self as acp};
  3use anyhow::Result;
  4use collections::IndexMap;
  5use futures::channel::mpsc::UnboundedReceiver;
  6use gpui::{Entity, SharedString, Task};
  7use project::Project;
  8use serde::{Deserialize, Serialize};
  9use std::{any::Any, error::Error, fmt, path::Path, rc::Rc, sync::Arc};
 10use ui::{App, IconName};
 11use uuid::Uuid;
 12
 13#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
 14pub struct UserMessageId(Arc<str>);
 15
 16impl UserMessageId {
 17    pub fn new() -> Self {
 18        Self(Uuid::new_v4().to_string().into())
 19    }
 20}
 21
 22pub trait AgentConnection {
 23    fn new_thread(
 24        self: Rc<Self>,
 25        project: Entity<Project>,
 26        cwd: &Path,
 27        cx: &mut App,
 28    ) -> Task<Result<Entity<AcpThread>>>;
 29
 30    fn list_threads(
 31        &self,
 32        _cx: &mut App,
 33    ) -> Option<watch::Receiver<Option<Vec<AcpThreadMetadata>>>> {
 34        return None;
 35    }
 36
 37    fn load_thread(
 38        self: Rc<Self>,
 39        _project: Entity<Project>,
 40        _cwd: &Path,
 41        _session_id: acp::SessionId,
 42        _cx: &mut App,
 43    ) -> Task<Result<Entity<AcpThread>>> {
 44        Task::ready(Err(anyhow::anyhow!("load thread not implemented")))
 45    }
 46
 47    fn auth_methods(&self) -> &[acp::AuthMethod];
 48
 49    fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;
 50
 51    fn prompt(
 52        &self,
 53        user_message_id: Option<UserMessageId>,
 54        params: acp::PromptRequest,
 55        cx: &mut App,
 56    ) -> Task<Result<acp::PromptResponse>>;
 57
 58    fn resume(
 59        &self,
 60        _session_id: &acp::SessionId,
 61        _cx: &mut App,
 62    ) -> Option<Rc<dyn AgentSessionResume>> {
 63        None
 64    }
 65
 66    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
 67
 68    fn session_editor(
 69        &self,
 70        _session_id: &acp::SessionId,
 71        _cx: &mut App,
 72    ) -> Option<Rc<dyn AgentSessionEditor>> {
 73        None
 74    }
 75
 76    /// Returns this agent as an [Rc<dyn ModelSelector>] if the model selection capability is supported.
 77    ///
 78    /// If the agent does not support model selection, returns [None].
 79    /// This allows sharing the selector in UI components.
 80    fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
 81        None
 82    }
 83
 84    fn into_any(self: Rc<Self>) -> Rc<dyn Any>;
 85}
 86
 87impl dyn AgentConnection {
 88    pub fn downcast<T: 'static + AgentConnection + Sized>(self: Rc<Self>) -> Option<Rc<T>> {
 89        self.into_any().downcast().ok()
 90    }
 91}
 92
 93pub trait AgentSessionEditor {
 94    fn truncate(&self, message_id: UserMessageId, cx: &mut App) -> Task<Result<()>>;
 95}
 96
 97pub trait AgentSessionResume {
 98    fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>>;
 99}
100
/// Unit error type used to report that authentication is required.
#[derive(Debug)]
pub struct AuthRequired;

impl fmt::Display for AuthRequired {
    /// Writes the fixed text `AuthRequired`; width and padding flags are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("AuthRequired")
    }
}

impl Error for AuthRequired {}
110
111/// Trait for agents that support listing, selecting, and querying language models.
112///
113/// This is an optional capability; agents indicate support via [AgentConnection::model_selector].
114pub trait AgentModelSelector: 'static {
115    /// Lists all available language models for this agent.
116    ///
117    /// # Parameters
118    /// - `cx`: The GPUI app context for async operations and global access.
119    ///
120    /// # Returns
121    /// A task resolving to the list of models or an error (e.g., if no models are configured).
122    fn list_models(&self, cx: &mut App) -> Task<Result<AgentModelList>>;
123
124    /// Selects a model for a specific session (thread).
125    ///
126    /// This sets the default model for future interactions in the session.
127    /// If the session doesn't exist or the model is invalid, it returns an error.
128    ///
129    /// # Parameters
130    /// - `session_id`: The ID of the session (thread) to apply the model to.
131    /// - `model`: The model to select (should be one from [list_models]).
132    /// - `cx`: The GPUI app context.
133    ///
134    /// # Returns
135    /// A task resolving to `Ok(())` on success or an error.
136    fn select_model(
137        &self,
138        session_id: acp::SessionId,
139        model_id: AgentModelId,
140        cx: &mut App,
141    ) -> Task<Result<()>>;
142
143    /// Retrieves the currently selected model for a specific session (thread).
144    ///
145    /// # Parameters
146    /// - `session_id`: The ID of the session (thread) to query.
147    /// - `cx`: The GPUI app context.
148    ///
149    /// # Returns
150    /// A task resolving to the selected model (always set) or an error (e.g., session not found).
151    fn selected_model(
152        &self,
153        session_id: &acp::SessionId,
154        cx: &mut App,
155    ) -> Task<Result<AgentModelInfo>>;
156
157    /// Whenever the model list is updated the receiver will be notified.
158    fn watch(&self, cx: &mut App) -> watch::Receiver<()>;
159}
160
161#[derive(Debug, Clone, PartialEq, Eq, Hash)]
162pub struct AgentModelId(pub SharedString);
163
164impl std::ops::Deref for AgentModelId {
165    type Target = SharedString;
166
167    fn deref(&self) -> &Self::Target {
168        &self.0
169    }
170}
171
172impl fmt::Display for AgentModelId {
173    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
174        self.0.fmt(f)
175    }
176}
177
178#[derive(Debug, Clone, PartialEq, Eq)]
179pub struct AgentModelInfo {
180    pub id: AgentModelId,
181    pub name: SharedString,
182    pub icon: Option<IconName>,
183}
184
185#[derive(Debug, Clone, PartialEq, Eq, Hash)]
186pub struct AgentModelGroupName(pub SharedString);
187
188#[derive(Debug, Clone)]
189pub enum AgentModelList {
190    Flat(Vec<AgentModelInfo>),
191    Grouped(IndexMap<AgentModelGroupName, Vec<AgentModelInfo>>),
192}
193
194impl AgentModelList {
195    pub fn is_empty(&self) -> bool {
196        match self {
197            AgentModelList::Flat(models) => models.is_empty(),
198            AgentModelList::Grouped(groups) => groups.is_empty(),
199        }
200    }
201}
202
203#[cfg(feature = "test-support")]
204mod test_support {
205    use std::sync::Arc;
206
207    use collections::HashMap;
208    use futures::future::try_join_all;
209    use gpui::{AppContext as _, WeakEntity};
210    use parking_lot::Mutex;
211
212    use super::*;
213
214    #[derive(Clone, Default)]
215    pub struct StubAgentConnection {
216        sessions: Arc<Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
217        permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
218        next_prompt_updates: Arc<Mutex<Vec<acp::SessionUpdate>>>,
219    }
220
221    impl StubAgentConnection {
222        pub fn new() -> Self {
223            Self {
224                next_prompt_updates: Default::default(),
225                permission_requests: HashMap::default(),
226                sessions: Arc::default(),
227            }
228        }
229
230        pub fn set_next_prompt_updates(&self, updates: Vec<acp::SessionUpdate>) {
231            *self.next_prompt_updates.lock() = updates;
232        }
233
234        pub fn with_permission_requests(
235            mut self,
236            permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
237        ) -> Self {
238            self.permission_requests = permission_requests;
239            self
240        }
241
242        pub fn send_update(
243            &self,
244            session_id: acp::SessionId,
245            update: acp::SessionUpdate,
246            cx: &mut App,
247        ) {
248            self.sessions
249                .lock()
250                .get(&session_id)
251                .unwrap()
252                .update(cx, |thread, cx| {
253                    thread.handle_session_update(update.clone(), cx).unwrap();
254                })
255                .unwrap();
256        }
257    }
258
259    impl AgentConnection for StubAgentConnection {
260        fn auth_methods(&self) -> &[acp::AuthMethod] {
261            &[]
262        }
263
264        fn new_thread(
265            self: Rc<Self>,
266            project: Entity<Project>,
267            _cwd: &Path,
268            cx: &mut gpui::App,
269        ) -> Task<gpui::Result<Entity<AcpThread>>> {
270            let session_id = acp::SessionId(self.sessions.lock().len().to_string().into());
271            let thread =
272                cx.new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx));
273            self.sessions.lock().insert(session_id, thread.downgrade());
274            Task::ready(Ok(thread))
275        }
276
277        fn authenticate(
278            &self,
279            _method_id: acp::AuthMethodId,
280            _cx: &mut App,
281        ) -> Task<gpui::Result<()>> {
282            unimplemented!()
283        }
284
285        fn prompt(
286            &self,
287            _id: Option<UserMessageId>,
288            params: acp::PromptRequest,
289            cx: &mut App,
290        ) -> Task<gpui::Result<acp::PromptResponse>> {
291            let sessions = self.sessions.lock();
292            let thread = sessions.get(&params.session_id).unwrap();
293            let mut tasks = vec![];
294            for update in self.next_prompt_updates.lock().drain(..) {
295                let thread = thread.clone();
296                let update = update.clone();
297                let permission_request = if let acp::SessionUpdate::ToolCall(tool_call) = &update
298                    && let Some(options) = self.permission_requests.get(&tool_call.id)
299                {
300                    Some((tool_call.clone(), options.clone()))
301                } else {
302                    None
303                };
304                let task = cx.spawn(async move |cx| {
305                    if let Some((tool_call, options)) = permission_request {
306                        let permission = thread.update(cx, |thread, cx| {
307                            thread.request_tool_call_authorization(
308                                tool_call.clone().into(),
309                                options.clone(),
310                                cx,
311                            )
312                        })?;
313                        permission?.await?;
314                    }
315                    thread.update(cx, |thread, cx| {
316                        thread.handle_session_update(update.clone(), cx).unwrap();
317                    })?;
318                    anyhow::Ok(())
319                });
320                tasks.push(task);
321            }
322            cx.spawn(async move |_| {
323                try_join_all(tasks).await?;
324                Ok(acp::PromptResponse {
325                    stop_reason: acp::StopReason::EndTurn,
326                })
327            })
328        }
329
330        fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) {
331            unimplemented!()
332        }
333
334        fn session_editor(
335            &self,
336            _session_id: &agent_client_protocol::SessionId,
337            _cx: &mut App,
338        ) -> Option<Rc<dyn AgentSessionEditor>> {
339            Some(Rc::new(StubAgentSessionEditor))
340        }
341
342        fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
343            self
344        }
345    }
346
347    struct StubAgentSessionEditor;
348
349    impl AgentSessionEditor for StubAgentSessionEditor {
350        fn truncate(&self, _: UserMessageId, _: &mut App) -> Task<Result<()>> {
351            Task::ready(Ok(()))
352        }
353    }
354}
355
356#[cfg(feature = "test-support")]
357pub use test_support::*;