connection.rs

  1use crate::AcpThread;
  2use agent_client_protocol::{self as acp};
  3use anyhow::Result;
  4use collections::IndexMap;
  5use gpui::{Entity, SharedString, Task};
  6use language_model::LanguageModelProviderId;
  7use project::Project;
  8use std::{any::Any, error::Error, fmt, path::Path, rc::Rc};
  9use ui::{App, IconName};
 10
 11pub trait AgentConnection {
 12    fn new_thread(
 13        self: Rc<Self>,
 14        project: Entity<Project>,
 15        cwd: &Path,
 16        cx: &mut App,
 17    ) -> Task<Result<Entity<AcpThread>>>;
 18
 19    fn auth_methods(&self) -> &[acp::AuthMethod];
 20
 21    fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;
 22
 23    fn prompt(&self, params: acp::PromptRequest, cx: &mut App)
 24    -> Task<Result<acp::PromptResponse>>;
 25
 26    fn resume(
 27        &self,
 28        _session_id: &acp::SessionId,
 29        _cx: &App,
 30    ) -> Option<Rc<dyn AgentSessionResume>> {
 31        None
 32    }
 33
 34    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
 35
 36    fn rewind(
 37        &self,
 38        _session_id: &acp::SessionId,
 39        _cx: &App,
 40    ) -> Option<Rc<dyn AgentSessionRewind>> {
 41        None
 42    }
 43
 44    fn set_title(
 45        &self,
 46        _session_id: &acp::SessionId,
 47        _cx: &App,
 48    ) -> Option<Rc<dyn AgentSessionSetTitle>> {
 49        None
 50    }
 51
 52    /// Returns this agent as an [Rc<dyn ModelSelector>] if the model selection capability is supported.
 53    ///
 54    /// If the agent does not support model selection, returns [None].
 55    /// This allows sharing the selector in UI components.
 56    fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
 57        None
 58    }
 59
 60    fn telemetry(&self) -> Option<Rc<dyn AgentTelemetry>> {
 61        None
 62    }
 63
 64    fn into_any(self: Rc<Self>) -> Rc<dyn Any>;
 65}
 66
 67impl dyn AgentConnection {
 68    pub fn downcast<T: 'static + AgentConnection + Sized>(self: Rc<Self>) -> Option<Rc<T>> {
 69        self.into_any().downcast().ok()
 70    }
 71}
 72
 73pub trait AgentSessionRewind {
 74    fn rewind(&self, message_id: acp::PromptId, cx: &mut App) -> Task<Result<()>>;
 75}
 76
 77pub trait AgentSessionResume {
 78    fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>>;
 79}
 80
 81pub trait AgentSessionSetTitle {
 82    fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>>;
 83}
 84
 85pub trait AgentTelemetry {
 86    /// The name of the agent used for telemetry.
 87    fn agent_name(&self) -> String;
 88
 89    /// A representation of the current thread state that can be serialized for
 90    /// storage with telemetry events.
 91    fn thread_data(
 92        &self,
 93        session_id: &acp::SessionId,
 94        cx: &mut App,
 95    ) -> Task<Result<serde_json::Value>>;
 96}
 97
 98#[derive(Debug)]
 99pub struct AuthRequired {
100    pub description: Option<String>,
101    pub provider_id: Option<LanguageModelProviderId>,
102}
103
104impl AuthRequired {
105    pub fn new() -> Self {
106        Self {
107            description: None,
108            provider_id: None,
109        }
110    }
111
112    pub fn with_description(mut self, description: String) -> Self {
113        self.description = Some(description);
114        self
115    }
116
117    pub fn with_language_model_provider(mut self, provider_id: LanguageModelProviderId) -> Self {
118        self.provider_id = Some(provider_id);
119        self
120    }
121}
122
123impl Error for AuthRequired {}
124impl fmt::Display for AuthRequired {
125    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
126        write!(f, "Authentication required")
127    }
128}
129
130/// Trait for agents that support listing, selecting, and querying language models.
131///
132/// This is an optional capability; agents indicate support via [AgentConnection::model_selector].
133pub trait AgentModelSelector: 'static {
134    /// Lists all available language models for this agent.
135    ///
136    /// # Parameters
137    /// - `cx`: The GPUI app context for async operations and global access.
138    ///
139    /// # Returns
140    /// A task resolving to the list of models or an error (e.g., if no models are configured).
141    fn list_models(&self, cx: &mut App) -> Task<Result<AgentModelList>>;
142
143    /// Selects a model for a specific session (thread).
144    ///
145    /// This sets the default model for future interactions in the session.
146    /// If the session doesn't exist or the model is invalid, it returns an error.
147    ///
148    /// # Parameters
149    /// - `session_id`: The ID of the session (thread) to apply the model to.
150    /// - `model`: The model to select (should be one from [list_models]).
151    /// - `cx`: The GPUI app context.
152    ///
153    /// # Returns
154    /// A task resolving to `Ok(())` on success or an error.
155    fn select_model(
156        &self,
157        session_id: acp::SessionId,
158        model_id: AgentModelId,
159        cx: &mut App,
160    ) -> Task<Result<()>>;
161
162    /// Retrieves the currently selected model for a specific session (thread).
163    ///
164    /// # Parameters
165    /// - `session_id`: The ID of the session (thread) to query.
166    /// - `cx`: The GPUI app context.
167    ///
168    /// # Returns
169    /// A task resolving to the selected model (always set) or an error (e.g., session not found).
170    fn selected_model(
171        &self,
172        session_id: &acp::SessionId,
173        cx: &mut App,
174    ) -> Task<Result<AgentModelInfo>>;
175
176    /// Whenever the model list is updated the receiver will be notified.
177    fn watch(&self, cx: &mut App) -> watch::Receiver<()>;
178}
179
180#[derive(Debug, Clone, PartialEq, Eq, Hash)]
181pub struct AgentModelId(pub SharedString);
182
183impl std::ops::Deref for AgentModelId {
184    type Target = SharedString;
185
186    fn deref(&self) -> &Self::Target {
187        &self.0
188    }
189}
190
191impl fmt::Display for AgentModelId {
192    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
193        self.0.fmt(f)
194    }
195}
196
197#[derive(Debug, Clone, PartialEq, Eq)]
198pub struct AgentModelInfo {
199    pub id: AgentModelId,
200    pub name: SharedString,
201    pub icon: Option<IconName>,
202}
203
204#[derive(Debug, Clone, PartialEq, Eq, Hash)]
205pub struct AgentModelGroupName(pub SharedString);
206
207#[derive(Debug, Clone)]
208pub enum AgentModelList {
209    Flat(Vec<AgentModelInfo>),
210    Grouped(IndexMap<AgentModelGroupName, Vec<AgentModelInfo>>),
211}
212
213impl AgentModelList {
214    pub fn is_empty(&self) -> bool {
215        match self {
216            AgentModelList::Flat(models) => models.is_empty(),
217            AgentModelList::Grouped(groups) => groups.is_empty(),
218        }
219    }
220}
221
222#[cfg(feature = "test-support")]
223mod test_support {
224    use std::sync::Arc;
225
226    use action_log::ActionLog;
227    use collections::HashMap;
228    use futures::{channel::oneshot, future::try_join_all};
229    use gpui::{AppContext as _, WeakEntity};
230    use parking_lot::Mutex;
231
232    use super::*;
233
234    #[derive(Clone, Default)]
235    pub struct StubAgentConnection {
236        sessions: Arc<Mutex<HashMap<acp::SessionId, Session>>>,
237        permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
238        next_prompt_updates: Arc<Mutex<Vec<acp::SessionUpdate>>>,
239    }
240
241    struct Session {
242        thread: WeakEntity<AcpThread>,
243        response_tx: Option<oneshot::Sender<acp::StopReason>>,
244    }
245
246    impl StubAgentConnection {
247        pub fn new() -> Self {
248            Self {
249                next_prompt_updates: Default::default(),
250                permission_requests: HashMap::default(),
251                sessions: Arc::default(),
252            }
253        }
254
255        pub fn set_next_prompt_updates(&self, updates: Vec<acp::SessionUpdate>) {
256            *self.next_prompt_updates.lock() = updates;
257        }
258
259        pub fn with_permission_requests(
260            mut self,
261            permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
262        ) -> Self {
263            self.permission_requests = permission_requests;
264            self
265        }
266
267        pub fn send_update(
268            &self,
269            session_id: acp::SessionId,
270            update: acp::SessionUpdate,
271            cx: &mut App,
272        ) {
273            assert!(
274                self.next_prompt_updates.lock().is_empty(),
275                "Use either send_update or set_next_prompt_updates"
276            );
277
278            self.sessions
279                .lock()
280                .get(&session_id)
281                .unwrap()
282                .thread
283                .update(cx, |thread, cx| {
284                    thread.handle_session_update(update, cx).unwrap();
285                })
286                .unwrap();
287        }
288
289        pub fn end_turn(&self, session_id: acp::SessionId, stop_reason: acp::StopReason) {
290            self.sessions
291                .lock()
292                .get_mut(&session_id)
293                .unwrap()
294                .response_tx
295                .take()
296                .expect("No pending turn")
297                .send(stop_reason)
298                .unwrap();
299        }
300    }
301
302    impl AgentConnection for StubAgentConnection {
303        fn auth_methods(&self) -> &[acp::AuthMethod] {
304            &[]
305        }
306
307        fn new_thread(
308            self: Rc<Self>,
309            project: Entity<Project>,
310            _cwd: &Path,
311            cx: &mut gpui::App,
312        ) -> Task<gpui::Result<Entity<AcpThread>>> {
313            let session_id = acp::SessionId(self.sessions.lock().len().to_string().into());
314            let action_log = cx.new(|_| ActionLog::new(project.clone()));
315            let thread = cx.new(|cx| {
316                AcpThread::new(
317                    "Test",
318                    self.clone(),
319                    project,
320                    action_log,
321                    session_id.clone(),
322                    watch::Receiver::constant(acp::PromptCapabilities {
323                        image: true,
324                        audio: true,
325                        embedded_context: true,
326                    }),
327                    cx,
328                )
329            });
330            self.sessions.lock().insert(
331                session_id,
332                Session {
333                    thread: thread.downgrade(),
334                    response_tx: None,
335                },
336            );
337            Task::ready(Ok(thread))
338        }
339
340        fn authenticate(
341            &self,
342            _method_id: acp::AuthMethodId,
343            _cx: &mut App,
344        ) -> Task<gpui::Result<()>> {
345            unimplemented!()
346        }
347
348        fn prompt(
349            &self,
350            params: acp::PromptRequest,
351            cx: &mut App,
352        ) -> Task<gpui::Result<acp::PromptResponse>> {
353            let mut sessions = self.sessions.lock();
354            let Session {
355                thread,
356                response_tx,
357            } = sessions.get_mut(&params.session_id).unwrap();
358            let mut tasks = vec![];
359            if self.next_prompt_updates.lock().is_empty() {
360                let (tx, rx) = oneshot::channel();
361                response_tx.replace(tx);
362                cx.spawn(async move |_| {
363                    let stop_reason = rx.await?;
364                    Ok(acp::PromptResponse { stop_reason })
365                })
366            } else {
367                for update in self.next_prompt_updates.lock().drain(..) {
368                    let thread = thread.clone();
369                    let update = update.clone();
370                    let permission_request = if let acp::SessionUpdate::ToolCall(tool_call) =
371                        &update
372                        && let Some(options) = self.permission_requests.get(&tool_call.id)
373                    {
374                        Some((tool_call.clone(), options.clone()))
375                    } else {
376                        None
377                    };
378                    let task = cx.spawn(async move |cx| {
379                        if let Some((tool_call, options)) = permission_request {
380                            let permission = thread.update(cx, |thread, cx| {
381                                thread.request_tool_call_authorization(
382                                    tool_call.clone().into(),
383                                    options.clone(),
384                                    cx,
385                                )
386                            })?;
387                            permission?.await?;
388                        }
389                        thread.update(cx, |thread, cx| {
390                            thread.handle_session_update(update.clone(), cx).unwrap();
391                        })?;
392                        anyhow::Ok(())
393                    });
394                    tasks.push(task);
395                }
396
397                cx.spawn(async move |_| {
398                    try_join_all(tasks).await?;
399                    Ok(acp::PromptResponse {
400                        stop_reason: acp::StopReason::EndTurn,
401                    })
402                })
403            }
404        }
405
406        fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
407            if let Some(end_turn_tx) = self
408                .sessions
409                .lock()
410                .get_mut(session_id)
411                .unwrap()
412                .response_tx
413                .take()
414            {
415                end_turn_tx.send(acp::StopReason::Cancelled).unwrap();
416            }
417        }
418
419        fn rewind(
420            &self,
421            _session_id: &agent_client_protocol::SessionId,
422            _cx: &App,
423        ) -> Option<Rc<dyn AgentSessionRewind>> {
424            Some(Rc::new(StubAgentSessionEditor))
425        }
426
427        fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
428            self
429        }
430    }
431
432    struct StubAgentSessionEditor;
433
434    impl AgentSessionRewind for StubAgentSessionEditor {
435        fn rewind(&self, _: acp::PromptId, _: &mut App) -> Task<Result<()>> {
436            Task::ready(Ok(()))
437        }
438    }
439}
440
// Expose the public test-support items (e.g. `StubAgentConnection`) at this
// module's level when the `test-support` feature is enabled.
#[cfg(feature = "test-support")]
pub use self::test_support::*;