connection.rs

  1use crate::AcpThread;
  2use agent_client_protocol::{self as acp};
  3use anyhow::Result;
  4use collections::IndexMap;
  5use gpui::{Entity, SharedString, Task};
  6use language_model::LanguageModelProviderId;
  7use project::Project;
  8use std::{any::Any, error::Error, fmt, path::Path, rc::Rc, sync::Arc};
  9use ui::{App, IconName};
 10use uuid::Uuid;
 11
 12#[derive(Clone, Debug, Eq, PartialEq)]
 13pub struct UserMessageId(Arc<str>);
 14
 15impl UserMessageId {
 16    pub fn new() -> Self {
 17        Self(Uuid::new_v4().to_string().into())
 18    }
 19}
 20
 21pub trait AgentConnection {
 22    fn new_thread(
 23        self: Rc<Self>,
 24        project: Entity<Project>,
 25        cwd: &Path,
 26        cx: &mut App,
 27    ) -> Task<Result<Entity<AcpThread>>>;
 28
 29    fn auth_methods(&self) -> &[acp::AuthMethod];
 30
 31    fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;
 32
 33    fn prompt(
 34        &self,
 35        user_message_id: Option<UserMessageId>,
 36        params: acp::PromptRequest,
 37        cx: &mut App,
 38    ) -> Task<Result<acp::PromptResponse>>;
 39
 40    fn resume(
 41        &self,
 42        _session_id: &acp::SessionId,
 43        _cx: &mut App,
 44    ) -> Option<Rc<dyn AgentSessionResume>> {
 45        None
 46    }
 47
 48    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
 49
 50    fn session_editor(
 51        &self,
 52        _session_id: &acp::SessionId,
 53        _cx: &mut App,
 54    ) -> Option<Rc<dyn AgentSessionEditor>> {
 55        None
 56    }
 57
 58    /// Returns this agent as an [Rc<dyn ModelSelector>] if the model selection capability is supported.
 59    ///
 60    /// If the agent does not support model selection, returns [None].
 61    /// This allows sharing the selector in UI components.
 62    fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
 63        None
 64    }
 65
 66    fn into_any(self: Rc<Self>) -> Rc<dyn Any>;
 67}
 68
 69impl dyn AgentConnection {
 70    pub fn downcast<T: 'static + AgentConnection + Sized>(self: Rc<Self>) -> Option<Rc<T>> {
 71        self.into_any().downcast().ok()
 72    }
 73}
 74
 75pub trait AgentSessionEditor {
 76    fn truncate(&self, message_id: UserMessageId, cx: &mut App) -> Task<Result<()>>;
 77}
 78
 79pub trait AgentSessionResume {
 80    fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>>;
 81}
 82
 83#[derive(Debug)]
 84pub struct AuthRequired {
 85    pub description: Option<String>,
 86    pub provider_id: Option<LanguageModelProviderId>,
 87}
 88
 89impl AuthRequired {
 90    pub fn new() -> Self {
 91        Self {
 92            description: None,
 93            provider_id: None,
 94        }
 95    }
 96
 97    pub fn with_description(mut self, description: String) -> Self {
 98        self.description = Some(description);
 99        self
100    }
101
102    pub fn with_language_model_provider(mut self, provider_id: LanguageModelProviderId) -> Self {
103        self.provider_id = Some(provider_id);
104        self
105    }
106}
107
108impl Error for AuthRequired {}
109impl fmt::Display for AuthRequired {
110    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
111        write!(f, "Authentication required")
112    }
113}
114
115/// Trait for agents that support listing, selecting, and querying language models.
116///
117/// This is an optional capability; agents indicate support via [AgentConnection::model_selector].
118pub trait AgentModelSelector: 'static {
119    /// Lists all available language models for this agent.
120    ///
121    /// # Parameters
122    /// - `cx`: The GPUI app context for async operations and global access.
123    ///
124    /// # Returns
125    /// A task resolving to the list of models or an error (e.g., if no models are configured).
126    fn list_models(&self, cx: &mut App) -> Task<Result<AgentModelList>>;
127
128    /// Selects a model for a specific session (thread).
129    ///
130    /// This sets the default model for future interactions in the session.
131    /// If the session doesn't exist or the model is invalid, it returns an error.
132    ///
133    /// # Parameters
134    /// - `session_id`: The ID of the session (thread) to apply the model to.
135    /// - `model`: The model to select (should be one from [list_models]).
136    /// - `cx`: The GPUI app context.
137    ///
138    /// # Returns
139    /// A task resolving to `Ok(())` on success or an error.
140    fn select_model(
141        &self,
142        session_id: acp::SessionId,
143        model_id: AgentModelId,
144        cx: &mut App,
145    ) -> Task<Result<()>>;
146
147    /// Retrieves the currently selected model for a specific session (thread).
148    ///
149    /// # Parameters
150    /// - `session_id`: The ID of the session (thread) to query.
151    /// - `cx`: The GPUI app context.
152    ///
153    /// # Returns
154    /// A task resolving to the selected model (always set) or an error (e.g., session not found).
155    fn selected_model(
156        &self,
157        session_id: &acp::SessionId,
158        cx: &mut App,
159    ) -> Task<Result<AgentModelInfo>>;
160
161    /// Whenever the model list is updated the receiver will be notified.
162    fn watch(&self, cx: &mut App) -> watch::Receiver<()>;
163}
164
165#[derive(Debug, Clone, PartialEq, Eq, Hash)]
166pub struct AgentModelId(pub SharedString);
167
168impl std::ops::Deref for AgentModelId {
169    type Target = SharedString;
170
171    fn deref(&self) -> &Self::Target {
172        &self.0
173    }
174}
175
176impl fmt::Display for AgentModelId {
177    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
178        self.0.fmt(f)
179    }
180}
181
182#[derive(Debug, Clone, PartialEq, Eq)]
183pub struct AgentModelInfo {
184    pub id: AgentModelId,
185    pub name: SharedString,
186    pub icon: Option<IconName>,
187}
188
189#[derive(Debug, Clone, PartialEq, Eq, Hash)]
190pub struct AgentModelGroupName(pub SharedString);
191
192#[derive(Debug, Clone)]
193pub enum AgentModelList {
194    Flat(Vec<AgentModelInfo>),
195    Grouped(IndexMap<AgentModelGroupName, Vec<AgentModelInfo>>),
196}
197
198impl AgentModelList {
199    pub fn is_empty(&self) -> bool {
200        match self {
201            AgentModelList::Flat(models) => models.is_empty(),
202            AgentModelList::Grouped(groups) => groups.is_empty(),
203        }
204    }
205}
206
#[cfg(feature = "test-support")]
mod test_support {
    //! A scriptable, in-memory [`AgentConnection`] for tests.

    use std::sync::Arc;

    use collections::HashMap;
    use futures::{channel::oneshot, future::try_join_all};
    use gpui::{AppContext as _, WeakEntity};
    use parking_lot::Mutex;

    use super::*;

    /// Test double for [`AgentConnection`].
    ///
    /// Each [`AgentConnection::prompt`] call does one of two things:
    /// - If updates were queued with [`StubAgentConnection::set_next_prompt_updates`], it
    ///   replays them and finishes with [`acp::StopReason::EndTurn`].
    /// - Otherwise, it stays pending until [`StubAgentConnection::end_turn`] or
    ///   [`AgentConnection::cancel`] completes it.
    #[derive(Clone, Default)]
    pub struct StubAgentConnection {
        /// Sessions created by `new_thread`, keyed by their generated id.
        sessions: Arc<Mutex<HashMap<acp::SessionId, Session>>>,
        /// Permission options to request before applying a `ToolCall` update with a matching id.
        permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
        /// Updates consumed and replayed by the next `prompt` call.
        next_prompt_updates: Arc<Mutex<Vec<acp::SessionUpdate>>>,
    }

    /// Bookkeeping for one session created by [`StubAgentConnection`].
    struct Session {
        thread: WeakEntity<AcpThread>,
        /// Completes the turn started by a `prompt` call that had no queued updates.
        response_tx: Option<oneshot::Sender<acp::StopReason>>,
    }

    impl StubAgentConnection {
        /// Creates a connection with no sessions, no queued updates, and no permission requests.
        pub fn new() -> Self {
            Self::default()
        }

        /// Queues `updates` to be replayed by the next `prompt` call, replacing any queued ones.
        pub fn set_next_prompt_updates(&self, updates: Vec<acp::SessionUpdate>) {
            *self.next_prompt_updates.lock() = updates;
        }

        /// Makes replayed `ToolCall` updates whose id is in `permission_requests` first request
        /// authorization with the associated options.
        pub fn with_permission_requests(
            mut self,
            permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
        ) -> Self {
            self.permission_requests = permission_requests;
            self
        }

        /// Applies `update` to the thread of `session_id` immediately.
        ///
        /// # Panics
        /// Panics in any of these cases:
        /// - updates are queued via `set_next_prompt_updates`;
        /// - the session is unknown;
        /// - the thread was dropped;
        /// - the thread rejects the update.
        pub fn send_update(
            &self,
            session_id: acp::SessionId,
            update: acp::SessionUpdate,
            cx: &mut App,
        ) {
            assert!(
                self.next_prompt_updates.lock().is_empty(),
                "Use either send_update or set_next_prompt_updates"
            );

            // Copy the handle out so the sessions lock is released before calling into the
            // thread; parking_lot mutexes are not re-entrant.
            let thread = self
                .sessions
                .lock()
                .get(&session_id)
                .expect("send_update called for an unknown session")
                .thread
                .clone();
            thread
                .update(cx, |thread, cx| {
                    thread.handle_session_update(update, cx).unwrap();
                })
                .unwrap();
        }

        /// Completes the pending turn of `session_id` with `stop_reason`.
        ///
        /// # Panics
        /// Panics if the session is unknown, has no pending turn, or its prompt task is gone.
        pub fn end_turn(&self, session_id: acp::SessionId, stop_reason: acp::StopReason) {
            self.sessions
                .lock()
                .get_mut(&session_id)
                .unwrap()
                .response_tx
                .take()
                .expect("No pending turn")
                .send(stop_reason)
                .unwrap();
        }
    }

    impl AgentConnection for StubAgentConnection {
        fn auth_methods(&self) -> &[acp::AuthMethod] {
            &[]
        }

        fn new_thread(
            self: Rc<Self>,
            project: Entity<Project>,
            _cwd: &Path,
            cx: &mut App,
        ) -> Task<Result<Entity<AcpThread>>> {
            // Sessions are never removed, so the current count is a unique id.
            let session_id = acp::SessionId(self.sessions.lock().len().to_string().into());
            let thread =
                cx.new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx));
            self.sessions.lock().insert(
                session_id,
                Session {
                    thread: thread.downgrade(),
                    response_tx: None,
                },
            );
            Task::ready(Ok(thread))
        }

        fn authenticate(&self, _method_id: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
            unimplemented!()
        }

        fn prompt(
            &self,
            _id: Option<UserMessageId>,
            params: acp::PromptRequest,
            cx: &mut App,
        ) -> Task<Result<acp::PromptResponse>> {
            let mut sessions = self.sessions.lock();
            let session = sessions
                .get_mut(&params.session_id)
                .expect("prompt called for an unknown session");
            let updates = std::mem::take(&mut *self.next_prompt_updates.lock());

            if updates.is_empty() {
                // Nothing scripted: keep the turn open until `end_turn` or `cancel` completes it.
                let (tx, rx) = oneshot::channel();
                session.response_tx = Some(tx);
                return cx.spawn(async move |_| {
                    let stop_reason = rx.await?;
                    Ok(acp::PromptResponse { stop_reason })
                });
            }

            let thread = session.thread.clone();
            drop(sessions);

            let mut tasks = Vec::with_capacity(updates.len());
            for update in updates {
                let thread = thread.clone();
                let permission_request = match &update {
                    acp::SessionUpdate::ToolCall(tool_call) => self
                        .permission_requests
                        .get(&tool_call.id)
                        .map(|options| (tool_call.clone(), options.clone())),
                    _ => None,
                };
                tasks.push(cx.spawn(async move |cx| {
                    if let Some((tool_call, options)) = permission_request {
                        // Hold the update back until the authorization request is answered.
                        let permission = thread.update(cx, |thread, cx| {
                            thread.request_tool_call_authorization(tool_call.into(), options, cx)
                        })?;
                        permission?.await?;
                    }
                    thread.update(cx, |thread, cx| {
                        thread.handle_session_update(update, cx).unwrap();
                    })?;
                    anyhow::Ok(())
                }));
            }

            cx.spawn(async move |_| {
                try_join_all(tasks).await?;
                Ok(acp::PromptResponse {
                    stop_reason: acp::StopReason::EndTurn,
                })
            })
        }

        fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
            let pending_turn = self
                .sessions
                .lock()
                .get_mut(session_id)
                .expect("cancel called for an unknown session")
                .response_tx
                .take();
            if let Some(end_turn_tx) = pending_turn {
                // The prompt task (and its receiver) may already have been dropped; there is
                // nobody left to notify then, so a failed send is not an error.
                let _ = end_turn_tx.send(acp::StopReason::Canceled);
            }
        }

        fn session_editor(
            &self,
            _session_id: &acp::SessionId,
            _cx: &mut App,
        ) -> Option<Rc<dyn AgentSessionEditor>> {
            Some(Rc::new(StubAgentSessionEditor))
        }

        fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
            self
        }
    }

    /// Session editor whose `truncate` succeeds immediately without doing anything.
    struct StubAgentSessionEditor;

    impl AgentSessionEditor for StubAgentSessionEditor {
        fn truncate(&self, _: UserMessageId, _: &mut App) -> Task<Result<()>> {
            Task::ready(Ok(()))
        }
    }
}

#[cfg(feature = "test-support")]
pub use test_support::*;