connection.rs

  1use crate::AcpThread;
  2use agent_client_protocol::{self as acp};
  3use anyhow::Result;
  4use collections::IndexMap;
  5use gpui::{Entity, SharedString, Task};
  6use project::Project;
  7use std::{error::Error, fmt, path::Path, rc::Rc, sync::Arc};
  8use ui::{App, IconName};
  9use uuid::Uuid;
 10
 11#[derive(Clone, Debug, Eq, PartialEq)]
 12pub struct UserMessageId(Arc<str>);
 13
 14impl UserMessageId {
 15    pub fn new() -> Self {
 16        Self(Uuid::new_v4().to_string().into())
 17    }
 18}
 19
 20pub trait AgentConnection {
 21    fn new_thread(
 22        self: Rc<Self>,
 23        project: Entity<Project>,
 24        cwd: &Path,
 25        cx: &mut App,
 26    ) -> Task<Result<Entity<AcpThread>>>;
 27
 28    fn auth_methods(&self) -> &[acp::AuthMethod];
 29
 30    fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;
 31
 32    fn prompt(
 33        &self,
 34        user_message_id: Option<UserMessageId>,
 35        params: acp::PromptRequest,
 36        cx: &mut App,
 37    ) -> Task<Result<acp::PromptResponse>>;
 38
 39    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
 40
 41    fn session_editor(
 42        &self,
 43        _session_id: &acp::SessionId,
 44        _cx: &mut App,
 45    ) -> Option<Rc<dyn AgentSessionEditor>> {
 46        None
 47    }
 48
 49    /// Returns this agent as an [Rc<dyn ModelSelector>] if the model selection capability is supported.
 50    ///
 51    /// If the agent does not support model selection, returns [None].
 52    /// This allows sharing the selector in UI components.
 53    fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
 54        None
 55    }
 56}
 57
/// Optional capability for editing an existing session's history.
///
/// Obtained through [`AgentConnection::session_editor`].
pub trait AgentSessionEditor {
    /// Truncates the session at the user message identified by `message_id`.
    ///
    /// NOTE(review): whether `message_id` itself is kept or removed is up to the
    /// implementation — confirm against implementors.
    fn truncate(&self, message_id: UserMessageId, cx: &mut App) -> Task<Result<()>>;
}
 61
 62#[derive(Debug)]
 63pub struct AuthRequired;
 64
 65impl Error for AuthRequired {}
 66impl fmt::Display for AuthRequired {
 67    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 68        write!(f, "AuthRequired")
 69    }
 70}
 71
 72/// Trait for agents that support listing, selecting, and querying language models.
 73///
 74/// This is an optional capability; agents indicate support via [AgentConnection::model_selector].
 75pub trait AgentModelSelector: 'static {
 76    /// Lists all available language models for this agent.
 77    ///
 78    /// # Parameters
 79    /// - `cx`: The GPUI app context for async operations and global access.
 80    ///
 81    /// # Returns
 82    /// A task resolving to the list of models or an error (e.g., if no models are configured).
 83    fn list_models(&self, cx: &mut App) -> Task<Result<AgentModelList>>;
 84
 85    /// Selects a model for a specific session (thread).
 86    ///
 87    /// This sets the default model for future interactions in the session.
 88    /// If the session doesn't exist or the model is invalid, it returns an error.
 89    ///
 90    /// # Parameters
 91    /// - `session_id`: The ID of the session (thread) to apply the model to.
 92    /// - `model`: The model to select (should be one from [list_models]).
 93    /// - `cx`: The GPUI app context.
 94    ///
 95    /// # Returns
 96    /// A task resolving to `Ok(())` on success or an error.
 97    fn select_model(
 98        &self,
 99        session_id: acp::SessionId,
100        model_id: AgentModelId,
101        cx: &mut App,
102    ) -> Task<Result<()>>;
103
104    /// Retrieves the currently selected model for a specific session (thread).
105    ///
106    /// # Parameters
107    /// - `session_id`: The ID of the session (thread) to query.
108    /// - `cx`: The GPUI app context.
109    ///
110    /// # Returns
111    /// A task resolving to the selected model (always set) or an error (e.g., session not found).
112    fn selected_model(
113        &self,
114        session_id: &acp::SessionId,
115        cx: &mut App,
116    ) -> Task<Result<AgentModelInfo>>;
117
118    /// Whenever the model list is updated the receiver will be notified.
119    fn watch(&self, cx: &mut App) -> watch::Receiver<()>;
120}
121
122#[derive(Debug, Clone, PartialEq, Eq, Hash)]
123pub struct AgentModelId(pub SharedString);
124
125impl std::ops::Deref for AgentModelId {
126    type Target = SharedString;
127
128    fn deref(&self) -> &Self::Target {
129        &self.0
130    }
131}
132
133impl fmt::Display for AgentModelId {
134    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
135        self.0.fmt(f)
136    }
137}
138
139#[derive(Debug, Clone, PartialEq, Eq)]
140pub struct AgentModelInfo {
141    pub id: AgentModelId,
142    pub name: SharedString,
143    pub icon: Option<IconName>,
144}
145
146#[derive(Debug, Clone, PartialEq, Eq, Hash)]
147pub struct AgentModelGroupName(pub SharedString);
148
149#[derive(Debug, Clone)]
150pub enum AgentModelList {
151    Flat(Vec<AgentModelInfo>),
152    Grouped(IndexMap<AgentModelGroupName, Vec<AgentModelInfo>>),
153}
154
155impl AgentModelList {
156    pub fn is_empty(&self) -> bool {
157        match self {
158            AgentModelList::Flat(models) => models.is_empty(),
159            AgentModelList::Grouped(groups) => groups.is_empty(),
160        }
161    }
162}
163
164#[cfg(feature = "test-support")]
165mod test_support {
166    use std::sync::Arc;
167
168    use collections::HashMap;
169    use futures::future::try_join_all;
170    use gpui::{AppContext as _, WeakEntity};
171    use parking_lot::Mutex;
172
173    use super::*;
174
175    #[derive(Clone, Default)]
176    pub struct StubAgentConnection {
177        sessions: Arc<Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
178        permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
179        next_prompt_updates: Arc<Mutex<Vec<acp::SessionUpdate>>>,
180    }
181
182    impl StubAgentConnection {
183        pub fn new() -> Self {
184            Self {
185                next_prompt_updates: Default::default(),
186                permission_requests: HashMap::default(),
187                sessions: Arc::default(),
188            }
189        }
190
191        pub fn set_next_prompt_updates(&self, updates: Vec<acp::SessionUpdate>) {
192            *self.next_prompt_updates.lock() = updates;
193        }
194
195        pub fn with_permission_requests(
196            mut self,
197            permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
198        ) -> Self {
199            self.permission_requests = permission_requests;
200            self
201        }
202
203        pub fn send_update(
204            &self,
205            session_id: acp::SessionId,
206            update: acp::SessionUpdate,
207            cx: &mut App,
208        ) {
209            self.sessions
210                .lock()
211                .get(&session_id)
212                .unwrap()
213                .update(cx, |thread, cx| {
214                    thread.handle_session_update(update.clone(), cx).unwrap();
215                })
216                .unwrap();
217        }
218    }
219
220    impl AgentConnection for StubAgentConnection {
221        fn auth_methods(&self) -> &[acp::AuthMethod] {
222            &[]
223        }
224
225        fn new_thread(
226            self: Rc<Self>,
227            project: Entity<Project>,
228            _cwd: &Path,
229            cx: &mut gpui::App,
230        ) -> Task<gpui::Result<Entity<AcpThread>>> {
231            let session_id = acp::SessionId(self.sessions.lock().len().to_string().into());
232            let thread =
233                cx.new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx));
234            self.sessions.lock().insert(session_id, thread.downgrade());
235            Task::ready(Ok(thread))
236        }
237
238        fn authenticate(
239            &self,
240            _method_id: acp::AuthMethodId,
241            _cx: &mut App,
242        ) -> Task<gpui::Result<()>> {
243            unimplemented!()
244        }
245
246        fn prompt(
247            &self,
248            _id: Option<UserMessageId>,
249            params: acp::PromptRequest,
250            cx: &mut App,
251        ) -> Task<gpui::Result<acp::PromptResponse>> {
252            let sessions = self.sessions.lock();
253            let thread = sessions.get(&params.session_id).unwrap();
254            let mut tasks = vec![];
255            for update in self.next_prompt_updates.lock().drain(..) {
256                let thread = thread.clone();
257                let update = update.clone();
258                let permission_request = if let acp::SessionUpdate::ToolCall(tool_call) = &update
259                    && let Some(options) = self.permission_requests.get(&tool_call.id)
260                {
261                    Some((tool_call.clone(), options.clone()))
262                } else {
263                    None
264                };
265                let task = cx.spawn(async move |cx| {
266                    if let Some((tool_call, options)) = permission_request {
267                        let permission = thread.update(cx, |thread, cx| {
268                            thread.request_tool_call_authorization(
269                                tool_call.clone(),
270                                options.clone(),
271                                cx,
272                            )
273                        })?;
274                        permission.await?;
275                    }
276                    thread.update(cx, |thread, cx| {
277                        thread.handle_session_update(update.clone(), cx).unwrap();
278                    })?;
279                    anyhow::Ok(())
280                });
281                tasks.push(task);
282            }
283            cx.spawn(async move |_| {
284                try_join_all(tasks).await?;
285                Ok(acp::PromptResponse {
286                    stop_reason: acp::StopReason::EndTurn,
287                })
288            })
289        }
290
291        fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) {
292            unimplemented!()
293        }
294
295        fn session_editor(
296            &self,
297            _session_id: &agent_client_protocol::SessionId,
298            _cx: &mut App,
299        ) -> Option<Rc<dyn AgentSessionEditor>> {
300            Some(Rc::new(StubAgentSessionEditor))
301        }
302    }
303
304    struct StubAgentSessionEditor;
305
306    impl AgentSessionEditor for StubAgentSessionEditor {
307        fn truncate(&self, _: UserMessageId, _: &mut App) -> Task<Result<()>> {
308            Task::ready(Ok(()))
309        }
310    }
311}
312
313#[cfg(feature = "test-support")]
314pub use test_support::*;