connection.rs

  1use crate::AcpThread;
  2use agent_client_protocol::{self as acp};
  3use anyhow::Result;
  4use collections::IndexMap;
  5use gpui::{Entity, SharedString, Task};
  6use language_model::LanguageModelProviderId;
  7use project::Project;
  8use serde::{Deserialize, Serialize};
  9use std::{any::Any, error::Error, fmt, path::Path, rc::Rc, sync::Arc};
 10use ui::{App, IconName};
 11use uuid::Uuid;
 12
 13#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
 14pub struct UserMessageId(Arc<str>);
 15
 16impl UserMessageId {
 17    pub fn new() -> Self {
 18        Self(Uuid::new_v4().to_string().into())
 19    }
 20}
 21
 22pub trait AgentConnection {
 23    fn new_thread(
 24        self: Rc<Self>,
 25        project: Entity<Project>,
 26        cwd: &Path,
 27        cx: &mut App,
 28    ) -> Task<Result<Entity<AcpThread>>>;
 29
 30    fn auth_methods(&self) -> &[acp::AuthMethod];
 31
 32    fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;
 33
 34    fn prompt(
 35        &self,
 36        user_message_id: Option<UserMessageId>,
 37        params: acp::PromptRequest,
 38        cx: &mut App,
 39    ) -> Task<Result<acp::PromptResponse>>;
 40
 41    fn resume(
 42        &self,
 43        _session_id: &acp::SessionId,
 44        _cx: &mut App,
 45    ) -> Option<Rc<dyn AgentSessionResume>> {
 46        None
 47    }
 48
 49    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
 50
 51    fn session_editor(
 52        &self,
 53        _session_id: &acp::SessionId,
 54        _cx: &mut App,
 55    ) -> Option<Rc<dyn AgentSessionEditor>> {
 56        None
 57    }
 58
 59    /// Returns this agent as an [Rc<dyn ModelSelector>] if the model selection capability is supported.
 60    ///
 61    /// If the agent does not support model selection, returns [None].
 62    /// This allows sharing the selector in UI components.
 63    fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
 64        None
 65    }
 66
 67    fn into_any(self: Rc<Self>) -> Rc<dyn Any>;
 68}
 69
 70impl dyn AgentConnection {
 71    pub fn downcast<T: 'static + AgentConnection + Sized>(self: Rc<Self>) -> Option<Rc<T>> {
 72        self.into_any().downcast().ok()
 73    }
 74}
 75
 76pub trait AgentSessionEditor {
 77    fn truncate(&self, message_id: UserMessageId, cx: &mut App) -> Task<Result<()>>;
 78}
 79
 80pub trait AgentSessionResume {
 81    fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>>;
 82}
 83
 84#[derive(Debug)]
 85pub struct AuthRequired {
 86    pub description: Option<String>,
 87    pub provider_id: Option<LanguageModelProviderId>,
 88}
 89
 90impl AuthRequired {
 91    pub fn new() -> Self {
 92        Self {
 93            description: None,
 94            provider_id: None,
 95        }
 96    }
 97
 98    pub fn with_description(mut self, description: String) -> Self {
 99        self.description = Some(description);
100        self
101    }
102
103    pub fn with_language_model_provider(mut self, provider_id: LanguageModelProviderId) -> Self {
104        self.provider_id = Some(provider_id);
105        self
106    }
107}
108
109impl Error for AuthRequired {}
110impl fmt::Display for AuthRequired {
111    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
112        write!(f, "Authentication required")
113    }
114}
115
116/// Trait for agents that support listing, selecting, and querying language models.
117///
118/// This is an optional capability; agents indicate support via [AgentConnection::model_selector].
119pub trait AgentModelSelector: 'static {
120    /// Lists all available language models for this agent.
121    ///
122    /// # Parameters
123    /// - `cx`: The GPUI app context for async operations and global access.
124    ///
125    /// # Returns
126    /// A task resolving to the list of models or an error (e.g., if no models are configured).
127    fn list_models(&self, cx: &mut App) -> Task<Result<AgentModelList>>;
128
129    /// Selects a model for a specific session (thread).
130    ///
131    /// This sets the default model for future interactions in the session.
132    /// If the session doesn't exist or the model is invalid, it returns an error.
133    ///
134    /// # Parameters
135    /// - `session_id`: The ID of the session (thread) to apply the model to.
136    /// - `model`: The model to select (should be one from [list_models]).
137    /// - `cx`: The GPUI app context.
138    ///
139    /// # Returns
140    /// A task resolving to `Ok(())` on success or an error.
141    fn select_model(
142        &self,
143        session_id: acp::SessionId,
144        model_id: AgentModelId,
145        cx: &mut App,
146    ) -> Task<Result<()>>;
147
148    /// Retrieves the currently selected model for a specific session (thread).
149    ///
150    /// # Parameters
151    /// - `session_id`: The ID of the session (thread) to query.
152    /// - `cx`: The GPUI app context.
153    ///
154    /// # Returns
155    /// A task resolving to the selected model (always set) or an error (e.g., session not found).
156    fn selected_model(
157        &self,
158        session_id: &acp::SessionId,
159        cx: &mut App,
160    ) -> Task<Result<AgentModelInfo>>;
161
162    /// Whenever the model list is updated the receiver will be notified.
163    fn watch(&self, cx: &mut App) -> watch::Receiver<()>;
164}
165
166#[derive(Debug, Clone, PartialEq, Eq, Hash)]
167pub struct AgentModelId(pub SharedString);
168
169impl std::ops::Deref for AgentModelId {
170    type Target = SharedString;
171
172    fn deref(&self) -> &Self::Target {
173        &self.0
174    }
175}
176
177impl fmt::Display for AgentModelId {
178    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
179        self.0.fmt(f)
180    }
181}
182
183#[derive(Debug, Clone, PartialEq, Eq)]
184pub struct AgentModelInfo {
185    pub id: AgentModelId,
186    pub name: SharedString,
187    pub icon: Option<IconName>,
188}
189
190#[derive(Debug, Clone, PartialEq, Eq, Hash)]
191pub struct AgentModelGroupName(pub SharedString);
192
193#[derive(Debug, Clone)]
194pub enum AgentModelList {
195    Flat(Vec<AgentModelInfo>),
196    Grouped(IndexMap<AgentModelGroupName, Vec<AgentModelInfo>>),
197}
198
199impl AgentModelList {
200    pub fn is_empty(&self) -> bool {
201        match self {
202            AgentModelList::Flat(models) => models.is_empty(),
203            AgentModelList::Grouped(groups) => groups.is_empty(),
204        }
205    }
206}
207
208#[cfg(feature = "test-support")]
209mod test_support {
210    use std::sync::Arc;
211
212    use action_log::ActionLog;
213    use collections::HashMap;
214    use futures::{channel::oneshot, future::try_join_all};
215    use gpui::{AppContext as _, WeakEntity};
216    use parking_lot::Mutex;
217
218    use super::*;
219
220    #[derive(Clone, Default)]
221    pub struct StubAgentConnection {
222        sessions: Arc<Mutex<HashMap<acp::SessionId, Session>>>,
223        permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
224        next_prompt_updates: Arc<Mutex<Vec<acp::SessionUpdate>>>,
225    }
226
227    struct Session {
228        thread: WeakEntity<AcpThread>,
229        response_tx: Option<oneshot::Sender<acp::StopReason>>,
230    }
231
232    impl StubAgentConnection {
233        pub fn new() -> Self {
234            Self {
235                next_prompt_updates: Default::default(),
236                permission_requests: HashMap::default(),
237                sessions: Arc::default(),
238            }
239        }
240
241        pub fn set_next_prompt_updates(&self, updates: Vec<acp::SessionUpdate>) {
242            *self.next_prompt_updates.lock() = updates;
243        }
244
245        pub fn with_permission_requests(
246            mut self,
247            permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
248        ) -> Self {
249            self.permission_requests = permission_requests;
250            self
251        }
252
253        pub fn send_update(
254            &self,
255            session_id: acp::SessionId,
256            update: acp::SessionUpdate,
257            cx: &mut App,
258        ) {
259            assert!(
260                self.next_prompt_updates.lock().is_empty(),
261                "Use either send_update or set_next_prompt_updates"
262            );
263
264            self.sessions
265                .lock()
266                .get(&session_id)
267                .unwrap()
268                .thread
269                .update(cx, |thread, cx| {
270                    thread.handle_session_update(update, cx).unwrap();
271                })
272                .unwrap();
273        }
274
275        pub fn end_turn(&self, session_id: acp::SessionId, stop_reason: acp::StopReason) {
276            self.sessions
277                .lock()
278                .get_mut(&session_id)
279                .unwrap()
280                .response_tx
281                .take()
282                .expect("No pending turn")
283                .send(stop_reason)
284                .unwrap();
285        }
286    }
287
288    impl AgentConnection for StubAgentConnection {
289        fn auth_methods(&self) -> &[acp::AuthMethod] {
290            &[]
291        }
292
293        fn new_thread(
294            self: Rc<Self>,
295            project: Entity<Project>,
296            _cwd: &Path,
297            cx: &mut gpui::App,
298        ) -> Task<gpui::Result<Entity<AcpThread>>> {
299            let session_id = acp::SessionId(self.sessions.lock().len().to_string().into());
300            let action_log = cx.new(|_| ActionLog::new(project.clone()));
301            let thread = cx.new(|_cx| {
302                AcpThread::new(
303                    "Test",
304                    self.clone(),
305                    project,
306                    action_log,
307                    session_id.clone(),
308                )
309            });
310            self.sessions.lock().insert(
311                session_id,
312                Session {
313                    thread: thread.downgrade(),
314                    response_tx: None,
315                },
316            );
317            Task::ready(Ok(thread))
318        }
319
320        fn authenticate(
321            &self,
322            _method_id: acp::AuthMethodId,
323            _cx: &mut App,
324        ) -> Task<gpui::Result<()>> {
325            unimplemented!()
326        }
327
328        fn prompt(
329            &self,
330            _id: Option<UserMessageId>,
331            params: acp::PromptRequest,
332            cx: &mut App,
333        ) -> Task<gpui::Result<acp::PromptResponse>> {
334            let mut sessions = self.sessions.lock();
335            let Session {
336                thread,
337                response_tx,
338            } = sessions.get_mut(&params.session_id).unwrap();
339            let mut tasks = vec![];
340            if self.next_prompt_updates.lock().is_empty() {
341                let (tx, rx) = oneshot::channel();
342                response_tx.replace(tx);
343                cx.spawn(async move |_| {
344                    let stop_reason = rx.await?;
345                    Ok(acp::PromptResponse { stop_reason })
346                })
347            } else {
348                for update in self.next_prompt_updates.lock().drain(..) {
349                    let thread = thread.clone();
350                    let update = update.clone();
351                    let permission_request = if let acp::SessionUpdate::ToolCall(tool_call) =
352                        &update
353                        && let Some(options) = self.permission_requests.get(&tool_call.id)
354                    {
355                        Some((tool_call.clone(), options.clone()))
356                    } else {
357                        None
358                    };
359                    let task = cx.spawn(async move |cx| {
360                        if let Some((tool_call, options)) = permission_request {
361                            let permission = thread.update(cx, |thread, cx| {
362                                thread.request_tool_call_authorization(
363                                    tool_call.clone().into(),
364                                    options.clone(),
365                                    cx,
366                                )
367                            })?;
368                            permission?.await?;
369                        }
370                        thread.update(cx, |thread, cx| {
371                            thread.handle_session_update(update.clone(), cx).unwrap();
372                        })?;
373                        anyhow::Ok(())
374                    });
375                    tasks.push(task);
376                }
377
378                cx.spawn(async move |_| {
379                    try_join_all(tasks).await?;
380                    Ok(acp::PromptResponse {
381                        stop_reason: acp::StopReason::EndTurn,
382                    })
383                })
384            }
385        }
386
387        fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
388            if let Some(end_turn_tx) = self
389                .sessions
390                .lock()
391                .get_mut(session_id)
392                .unwrap()
393                .response_tx
394                .take()
395            {
396                end_turn_tx.send(acp::StopReason::Canceled).unwrap();
397            }
398        }
399
400        fn session_editor(
401            &self,
402            _session_id: &agent_client_protocol::SessionId,
403            _cx: &mut App,
404        ) -> Option<Rc<dyn AgentSessionEditor>> {
405            Some(Rc::new(StubAgentSessionEditor))
406        }
407
408        fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
409            self
410        }
411    }
412
413    struct StubAgentSessionEditor;
414
415    impl AgentSessionEditor for StubAgentSessionEditor {
416        fn truncate(&self, _: UserMessageId, _: &mut App) -> Task<Result<()>> {
417            Task::ready(Ok(()))
418        }
419    }
420}
421
422#[cfg(feature = "test-support")]
423pub use test_support::*;