connection.rs

  1use crate::{AcpThread, AcpThreadMetadata};
  2use agent_client_protocol::{self as acp};
  3use anyhow::Result;
  4use collections::IndexMap;
  5use futures::channel::mpsc::UnboundedReceiver;
  6use gpui::{Entity, SharedString, Task};
  7use project::Project;
  8use serde::{Deserialize, Serialize};
  9use std::{any::Any, error::Error, fmt, path::Path, rc::Rc, sync::Arc};
 10use ui::{App, IconName};
 11use uuid::Uuid;
 12
 13#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
 14pub struct UserMessageId(Arc<str>);
 15
 16impl UserMessageId {
 17    pub fn new() -> Self {
 18        Self(Uuid::new_v4().to_string().into())
 19    }
 20}
 21
 22pub trait AgentConnection {
 23    fn new_thread(
 24        self: Rc<Self>,
 25        project: Entity<Project>,
 26        cwd: &Path,
 27        cx: &mut App,
 28    ) -> Task<Result<Entity<AcpThread>>>;
 29
 30    fn list_threads(&self, _cx: &mut App) -> Option<UnboundedReceiver<Vec<AcpThreadMetadata>>> {
 31        return None;
 32    }
 33
 34    fn load_thread(
 35        self: Rc<Self>,
 36        _project: Entity<Project>,
 37        _cwd: &Path,
 38        _session_id: acp::SessionId,
 39        _cx: &mut App,
 40    ) -> Task<Result<Entity<AcpThread>>> {
 41        Task::ready(Err(anyhow::anyhow!("load thread not implemented")))
 42    }
 43
 44    fn auth_methods(&self) -> &[acp::AuthMethod];
 45
 46    fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;
 47
 48    fn prompt(
 49        &self,
 50        user_message_id: Option<UserMessageId>,
 51        params: acp::PromptRequest,
 52        cx: &mut App,
 53    ) -> Task<Result<acp::PromptResponse>>;
 54
 55    fn resume(
 56        &self,
 57        _session_id: &acp::SessionId,
 58        _cx: &mut App,
 59    ) -> Option<Rc<dyn AgentSessionResume>> {
 60        None
 61    }
 62
 63    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
 64
 65    fn session_editor(
 66        &self,
 67        _session_id: &acp::SessionId,
 68        _cx: &mut App,
 69    ) -> Option<Rc<dyn AgentSessionEditor>> {
 70        None
 71    }
 72
 73    /// Returns this agent as an [Rc<dyn ModelSelector>] if the model selection capability is supported.
 74    ///
 75    /// If the agent does not support model selection, returns [None].
 76    /// This allows sharing the selector in UI components.
 77    fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
 78        None
 79    }
 80
 81    fn into_any(self: Rc<Self>) -> Rc<dyn Any>;
 82}
 83
 84impl dyn AgentConnection {
 85    pub fn downcast<T: 'static + AgentConnection + Sized>(self: Rc<Self>) -> Option<Rc<T>> {
 86        self.into_any().downcast().ok()
 87    }
 88}
 89
 90pub trait AgentSessionEditor {
 91    fn truncate(&self, message_id: UserMessageId, cx: &mut App) -> Task<Result<()>>;
 92}
 93
 94pub trait AgentSessionResume {
 95    fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>>;
 96}
 97
 /// Marker error reporting that authentication is required.
///
/// Both its `Display` and `Debug` output are the literal text `AuthRequired`.
#[derive(Debug)]
pub struct AuthRequired;

impl fmt::Display for AuthRequired {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("AuthRequired")
    }
}

impl Error for AuthRequired {}
107
108/// Trait for agents that support listing, selecting, and querying language models.
109///
110/// This is an optional capability; agents indicate support via [AgentConnection::model_selector].
111pub trait AgentModelSelector: 'static {
112    /// Lists all available language models for this agent.
113    ///
114    /// # Parameters
115    /// - `cx`: The GPUI app context for async operations and global access.
116    ///
117    /// # Returns
118    /// A task resolving to the list of models or an error (e.g., if no models are configured).
119    fn list_models(&self, cx: &mut App) -> Task<Result<AgentModelList>>;
120
121    /// Selects a model for a specific session (thread).
122    ///
123    /// This sets the default model for future interactions in the session.
124    /// If the session doesn't exist or the model is invalid, it returns an error.
125    ///
126    /// # Parameters
127    /// - `session_id`: The ID of the session (thread) to apply the model to.
128    /// - `model`: The model to select (should be one from [list_models]).
129    /// - `cx`: The GPUI app context.
130    ///
131    /// # Returns
132    /// A task resolving to `Ok(())` on success or an error.
133    fn select_model(
134        &self,
135        session_id: acp::SessionId,
136        model_id: AgentModelId,
137        cx: &mut App,
138    ) -> Task<Result<()>>;
139
140    /// Retrieves the currently selected model for a specific session (thread).
141    ///
142    /// # Parameters
143    /// - `session_id`: The ID of the session (thread) to query.
144    /// - `cx`: The GPUI app context.
145    ///
146    /// # Returns
147    /// A task resolving to the selected model (always set) or an error (e.g., session not found).
148    fn selected_model(
149        &self,
150        session_id: &acp::SessionId,
151        cx: &mut App,
152    ) -> Task<Result<AgentModelInfo>>;
153
154    /// Whenever the model list is updated the receiver will be notified.
155    fn watch(&self, cx: &mut App) -> watch::Receiver<()>;
156}
157
158#[derive(Debug, Clone, PartialEq, Eq, Hash)]
159pub struct AgentModelId(pub SharedString);
160
161impl std::ops::Deref for AgentModelId {
162    type Target = SharedString;
163
164    fn deref(&self) -> &Self::Target {
165        &self.0
166    }
167}
168
169impl fmt::Display for AgentModelId {
170    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
171        self.0.fmt(f)
172    }
173}
174
175#[derive(Debug, Clone, PartialEq, Eq)]
176pub struct AgentModelInfo {
177    pub id: AgentModelId,
178    pub name: SharedString,
179    pub icon: Option<IconName>,
180}
181
182#[derive(Debug, Clone, PartialEq, Eq, Hash)]
183pub struct AgentModelGroupName(pub SharedString);
184
185#[derive(Debug, Clone)]
186pub enum AgentModelList {
187    Flat(Vec<AgentModelInfo>),
188    Grouped(IndexMap<AgentModelGroupName, Vec<AgentModelInfo>>),
189}
190
191impl AgentModelList {
192    pub fn is_empty(&self) -> bool {
193        match self {
194            AgentModelList::Flat(models) => models.is_empty(),
195            AgentModelList::Grouped(groups) => groups.is_empty(),
196        }
197    }
198}
199
200#[cfg(feature = "test-support")]
201mod test_support {
202    use std::sync::Arc;
203
204    use collections::HashMap;
205    use futures::future::try_join_all;
206    use gpui::{AppContext as _, WeakEntity};
207    use parking_lot::Mutex;
208
209    use super::*;
210
211    #[derive(Clone, Default)]
212    pub struct StubAgentConnection {
213        sessions: Arc<Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
214        permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
215        next_prompt_updates: Arc<Mutex<Vec<acp::SessionUpdate>>>,
216    }
217
218    impl StubAgentConnection {
219        pub fn new() -> Self {
220            Self {
221                next_prompt_updates: Default::default(),
222                permission_requests: HashMap::default(),
223                sessions: Arc::default(),
224            }
225        }
226
227        pub fn set_next_prompt_updates(&self, updates: Vec<acp::SessionUpdate>) {
228            *self.next_prompt_updates.lock() = updates;
229        }
230
231        pub fn with_permission_requests(
232            mut self,
233            permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
234        ) -> Self {
235            self.permission_requests = permission_requests;
236            self
237        }
238
239        pub fn send_update(
240            &self,
241            session_id: acp::SessionId,
242            update: acp::SessionUpdate,
243            cx: &mut App,
244        ) {
245            self.sessions
246                .lock()
247                .get(&session_id)
248                .unwrap()
249                .update(cx, |thread, cx| {
250                    thread.handle_session_update(update.clone(), cx).unwrap();
251                })
252                .unwrap();
253        }
254    }
255
256    impl AgentConnection for StubAgentConnection {
257        fn auth_methods(&self) -> &[acp::AuthMethod] {
258            &[]
259        }
260
261        fn new_thread(
262            self: Rc<Self>,
263            project: Entity<Project>,
264            _cwd: &Path,
265            cx: &mut gpui::App,
266        ) -> Task<gpui::Result<Entity<AcpThread>>> {
267            let session_id = acp::SessionId(self.sessions.lock().len().to_string().into());
268            let thread =
269                cx.new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx));
270            self.sessions.lock().insert(session_id, thread.downgrade());
271            Task::ready(Ok(thread))
272        }
273
274        fn authenticate(
275            &self,
276            _method_id: acp::AuthMethodId,
277            _cx: &mut App,
278        ) -> Task<gpui::Result<()>> {
279            unimplemented!()
280        }
281
282        fn prompt(
283            &self,
284            _id: Option<UserMessageId>,
285            params: acp::PromptRequest,
286            cx: &mut App,
287        ) -> Task<gpui::Result<acp::PromptResponse>> {
288            let sessions = self.sessions.lock();
289            let thread = sessions.get(&params.session_id).unwrap();
290            let mut tasks = vec![];
291            for update in self.next_prompt_updates.lock().drain(..) {
292                let thread = thread.clone();
293                let update = update.clone();
294                let permission_request = if let acp::SessionUpdate::ToolCall(tool_call) = &update
295                    && let Some(options) = self.permission_requests.get(&tool_call.id)
296                {
297                    Some((tool_call.clone(), options.clone()))
298                } else {
299                    None
300                };
301                let task = cx.spawn(async move |cx| {
302                    if let Some((tool_call, options)) = permission_request {
303                        let permission = thread.update(cx, |thread, cx| {
304                            thread.request_tool_call_authorization(
305                                tool_call.clone().into(),
306                                options.clone(),
307                                cx,
308                            )
309                        })?;
310                        permission?.await?;
311                    }
312                    thread.update(cx, |thread, cx| {
313                        thread.handle_session_update(update.clone(), cx).unwrap();
314                    })?;
315                    anyhow::Ok(())
316                });
317                tasks.push(task);
318            }
319            cx.spawn(async move |_| {
320                try_join_all(tasks).await?;
321                Ok(acp::PromptResponse {
322                    stop_reason: acp::StopReason::EndTurn,
323                })
324            })
325        }
326
327        fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) {
328            unimplemented!()
329        }
330
331        fn session_editor(
332            &self,
333            _session_id: &agent_client_protocol::SessionId,
334            _cx: &mut App,
335        ) -> Option<Rc<dyn AgentSessionEditor>> {
336            Some(Rc::new(StubAgentSessionEditor))
337        }
338
339        fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
340            self
341        }
342    }
343
344    struct StubAgentSessionEditor;
345
346    impl AgentSessionEditor for StubAgentSessionEditor {
347        fn truncate(&self, _: UserMessageId, _: &mut App) -> Task<Result<()>> {
348            Task::ready(Ok(()))
349        }
350    }
351}
352
353#[cfg(feature = "test-support")]
354pub use test_support::*;