connection.rs

  1use crate::AcpThread;
  2use agent_client_protocol::{self as acp};
  3use anyhow::Result;
  4use collections::IndexMap;
  5use gpui::{Entity, SharedString, Task};
  6use language_model::LanguageModelProviderId;
  7use project::Project;
  8use serde::{Deserialize, Serialize};
  9use std::{any::Any, error::Error, fmt, path::Path, rc::Rc, sync::Arc};
 10use ui::{App, IconName};
 11use uuid::Uuid;
 12
 13#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
 14pub struct UserMessageId(Arc<str>);
 15
 16impl UserMessageId {
 17    pub fn new() -> Self {
 18        Self(Uuid::new_v4().to_string().into())
 19    }
 20}
 21
 22pub trait AgentConnection {
 23    fn new_thread(
 24        self: Rc<Self>,
 25        project: Entity<Project>,
 26        cwd: &Path,
 27        cx: &mut App,
 28    ) -> Task<Result<Entity<AcpThread>>>;
 29
 30    fn auth_methods(&self) -> &[acp::AuthMethod];
 31
 32    fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;
 33
 34    fn prompt(
 35        &self,
 36        user_message_id: Option<UserMessageId>,
 37        params: acp::PromptRequest,
 38        cx: &mut App,
 39    ) -> Task<Result<acp::PromptResponse>>;
 40
 41    fn prompt_capabilities(&self) -> acp::PromptCapabilities;
 42
 43    fn resume(
 44        &self,
 45        _session_id: &acp::SessionId,
 46        _cx: &mut App,
 47    ) -> Option<Rc<dyn AgentSessionResume>> {
 48        None
 49    }
 50
 51    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
 52
 53    fn truncate(
 54        &self,
 55        _session_id: &acp::SessionId,
 56        _cx: &mut App,
 57    ) -> Option<Rc<dyn AgentSessionTruncate>> {
 58        None
 59    }
 60
 61    fn set_title(
 62        &self,
 63        _session_id: &acp::SessionId,
 64        _cx: &mut App,
 65    ) -> Option<Rc<dyn AgentSessionSetTitle>> {
 66        None
 67    }
 68
 69    /// Returns this agent as an [Rc<dyn ModelSelector>] if the model selection capability is supported.
 70    ///
 71    /// If the agent does not support model selection, returns [None].
 72    /// This allows sharing the selector in UI components.
 73    fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
 74        None
 75    }
 76
 77    fn telemetry(&self) -> Option<Rc<dyn AgentTelemetry>> {
 78        None
 79    }
 80
 81    fn into_any(self: Rc<Self>) -> Rc<dyn Any>;
 82}
 83
 84impl dyn AgentConnection {
 85    pub fn downcast<T: 'static + AgentConnection + Sized>(self: Rc<Self>) -> Option<Rc<T>> {
 86        self.into_any().downcast().ok()
 87    }
 88}
 89
 90pub trait AgentSessionTruncate {
 91    fn run(&self, message_id: UserMessageId, cx: &mut App) -> Task<Result<()>>;
 92}
 93
 94pub trait AgentSessionResume {
 95    fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>>;
 96}
 97
 98pub trait AgentSessionSetTitle {
 99    fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>>;
100}
101
102pub trait AgentTelemetry {
103    /// The name of the agent used for telemetry.
104    fn agent_name(&self) -> String;
105
106    /// A representation of the current thread state that can be serialized for
107    /// storage with telemetry events.
108    fn thread_data(
109        &self,
110        session_id: &acp::SessionId,
111        cx: &mut App,
112    ) -> Task<Result<serde_json::Value>>;
113}
114
115#[derive(Debug)]
116pub struct AuthRequired {
117    pub description: Option<String>,
118    pub provider_id: Option<LanguageModelProviderId>,
119}
120
121impl AuthRequired {
122    pub fn new() -> Self {
123        Self {
124            description: None,
125            provider_id: None,
126        }
127    }
128
129    pub fn with_description(mut self, description: String) -> Self {
130        self.description = Some(description);
131        self
132    }
133
134    pub fn with_language_model_provider(mut self, provider_id: LanguageModelProviderId) -> Self {
135        self.provider_id = Some(provider_id);
136        self
137    }
138}
139
140impl Error for AuthRequired {}
141impl fmt::Display for AuthRequired {
142    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
143        write!(f, "Authentication required")
144    }
145}
146
147/// Trait for agents that support listing, selecting, and querying language models.
148///
149/// This is an optional capability; agents indicate support via [AgentConnection::model_selector].
150pub trait AgentModelSelector: 'static {
151    /// Lists all available language models for this agent.
152    ///
153    /// # Parameters
154    /// - `cx`: The GPUI app context for async operations and global access.
155    ///
156    /// # Returns
157    /// A task resolving to the list of models or an error (e.g., if no models are configured).
158    fn list_models(&self, cx: &mut App) -> Task<Result<AgentModelList>>;
159
160    /// Selects a model for a specific session (thread).
161    ///
162    /// This sets the default model for future interactions in the session.
163    /// If the session doesn't exist or the model is invalid, it returns an error.
164    ///
165    /// # Parameters
166    /// - `session_id`: The ID of the session (thread) to apply the model to.
167    /// - `model`: The model to select (should be one from [list_models]).
168    /// - `cx`: The GPUI app context.
169    ///
170    /// # Returns
171    /// A task resolving to `Ok(())` on success or an error.
172    fn select_model(
173        &self,
174        session_id: acp::SessionId,
175        model_id: AgentModelId,
176        cx: &mut App,
177    ) -> Task<Result<()>>;
178
179    /// Retrieves the currently selected model for a specific session (thread).
180    ///
181    /// # Parameters
182    /// - `session_id`: The ID of the session (thread) to query.
183    /// - `cx`: The GPUI app context.
184    ///
185    /// # Returns
186    /// A task resolving to the selected model (always set) or an error (e.g., session not found).
187    fn selected_model(
188        &self,
189        session_id: &acp::SessionId,
190        cx: &mut App,
191    ) -> Task<Result<AgentModelInfo>>;
192
193    /// Whenever the model list is updated the receiver will be notified.
194    fn watch(&self, cx: &mut App) -> watch::Receiver<()>;
195}
196
197#[derive(Debug, Clone, PartialEq, Eq, Hash)]
198pub struct AgentModelId(pub SharedString);
199
200impl std::ops::Deref for AgentModelId {
201    type Target = SharedString;
202
203    fn deref(&self) -> &Self::Target {
204        &self.0
205    }
206}
207
208impl fmt::Display for AgentModelId {
209    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
210        self.0.fmt(f)
211    }
212}
213
214#[derive(Debug, Clone, PartialEq, Eq)]
215pub struct AgentModelInfo {
216    pub id: AgentModelId,
217    pub name: SharedString,
218    pub icon: Option<IconName>,
219}
220
221#[derive(Debug, Clone, PartialEq, Eq, Hash)]
222pub struct AgentModelGroupName(pub SharedString);
223
224#[derive(Debug, Clone)]
225pub enum AgentModelList {
226    Flat(Vec<AgentModelInfo>),
227    Grouped(IndexMap<AgentModelGroupName, Vec<AgentModelInfo>>),
228}
229
230impl AgentModelList {
231    pub fn is_empty(&self) -> bool {
232        match self {
233            AgentModelList::Flat(models) => models.is_empty(),
234            AgentModelList::Grouped(groups) => groups.is_empty(),
235        }
236    }
237}
238
239#[cfg(feature = "test-support")]
240mod test_support {
241    use std::sync::Arc;
242
243    use action_log::ActionLog;
244    use collections::HashMap;
245    use futures::{channel::oneshot, future::try_join_all};
246    use gpui::{AppContext as _, WeakEntity};
247    use parking_lot::Mutex;
248
249    use super::*;
250
251    #[derive(Clone, Default)]
252    pub struct StubAgentConnection {
253        sessions: Arc<Mutex<HashMap<acp::SessionId, Session>>>,
254        permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
255        next_prompt_updates: Arc<Mutex<Vec<acp::SessionUpdate>>>,
256    }
257
258    struct Session {
259        thread: WeakEntity<AcpThread>,
260        response_tx: Option<oneshot::Sender<acp::StopReason>>,
261    }
262
263    impl StubAgentConnection {
264        pub fn new() -> Self {
265            Self {
266                next_prompt_updates: Default::default(),
267                permission_requests: HashMap::default(),
268                sessions: Arc::default(),
269            }
270        }
271
272        pub fn set_next_prompt_updates(&self, updates: Vec<acp::SessionUpdate>) {
273            *self.next_prompt_updates.lock() = updates;
274        }
275
276        pub fn with_permission_requests(
277            mut self,
278            permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
279        ) -> Self {
280            self.permission_requests = permission_requests;
281            self
282        }
283
284        pub fn send_update(
285            &self,
286            session_id: acp::SessionId,
287            update: acp::SessionUpdate,
288            cx: &mut App,
289        ) {
290            assert!(
291                self.next_prompt_updates.lock().is_empty(),
292                "Use either send_update or set_next_prompt_updates"
293            );
294
295            self.sessions
296                .lock()
297                .get(&session_id)
298                .unwrap()
299                .thread
300                .update(cx, |thread, cx| {
301                    thread.handle_session_update(update, cx).unwrap();
302                })
303                .unwrap();
304        }
305
306        pub fn end_turn(&self, session_id: acp::SessionId, stop_reason: acp::StopReason) {
307            self.sessions
308                .lock()
309                .get_mut(&session_id)
310                .unwrap()
311                .response_tx
312                .take()
313                .expect("No pending turn")
314                .send(stop_reason)
315                .unwrap();
316        }
317    }
318
319    impl AgentConnection for StubAgentConnection {
320        fn auth_methods(&self) -> &[acp::AuthMethod] {
321            &[]
322        }
323
324        fn new_thread(
325            self: Rc<Self>,
326            project: Entity<Project>,
327            _cwd: &Path,
328            cx: &mut gpui::App,
329        ) -> Task<gpui::Result<Entity<AcpThread>>> {
330            let session_id = acp::SessionId(self.sessions.lock().len().to_string().into());
331            let action_log = cx.new(|_| ActionLog::new(project.clone()));
332            let thread = cx.new(|_cx| {
333                AcpThread::new(
334                    "Test",
335                    self.clone(),
336                    project,
337                    action_log,
338                    session_id.clone(),
339                )
340            });
341            self.sessions.lock().insert(
342                session_id,
343                Session {
344                    thread: thread.downgrade(),
345                    response_tx: None,
346                },
347            );
348            Task::ready(Ok(thread))
349        }
350
351        fn prompt_capabilities(&self) -> acp::PromptCapabilities {
352            acp::PromptCapabilities {
353                image: true,
354                audio: true,
355                embedded_context: true,
356            }
357        }
358
359        fn authenticate(
360            &self,
361            _method_id: acp::AuthMethodId,
362            _cx: &mut App,
363        ) -> Task<gpui::Result<()>> {
364            unimplemented!()
365        }
366
367        fn prompt(
368            &self,
369            _id: Option<UserMessageId>,
370            params: acp::PromptRequest,
371            cx: &mut App,
372        ) -> Task<gpui::Result<acp::PromptResponse>> {
373            let mut sessions = self.sessions.lock();
374            let Session {
375                thread,
376                response_tx,
377            } = sessions.get_mut(&params.session_id).unwrap();
378            let mut tasks = vec![];
379            if self.next_prompt_updates.lock().is_empty() {
380                let (tx, rx) = oneshot::channel();
381                response_tx.replace(tx);
382                cx.spawn(async move |_| {
383                    let stop_reason = rx.await?;
384                    Ok(acp::PromptResponse { stop_reason })
385                })
386            } else {
387                for update in self.next_prompt_updates.lock().drain(..) {
388                    let thread = thread.clone();
389                    let update = update.clone();
390                    let permission_request = if let acp::SessionUpdate::ToolCall(tool_call) =
391                        &update
392                        && let Some(options) = self.permission_requests.get(&tool_call.id)
393                    {
394                        Some((tool_call.clone(), options.clone()))
395                    } else {
396                        None
397                    };
398                    let task = cx.spawn(async move |cx| {
399                        if let Some((tool_call, options)) = permission_request {
400                            let permission = thread.update(cx, |thread, cx| {
401                                thread.request_tool_call_authorization(
402                                    tool_call.clone().into(),
403                                    options.clone(),
404                                    cx,
405                                )
406                            })?;
407                            permission?.await?;
408                        }
409                        thread.update(cx, |thread, cx| {
410                            thread.handle_session_update(update.clone(), cx).unwrap();
411                        })?;
412                        anyhow::Ok(())
413                    });
414                    tasks.push(task);
415                }
416
417                cx.spawn(async move |_| {
418                    try_join_all(tasks).await?;
419                    Ok(acp::PromptResponse {
420                        stop_reason: acp::StopReason::EndTurn,
421                    })
422                })
423            }
424        }
425
426        fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
427            if let Some(end_turn_tx) = self
428                .sessions
429                .lock()
430                .get_mut(session_id)
431                .unwrap()
432                .response_tx
433                .take()
434            {
435                end_turn_tx.send(acp::StopReason::Cancelled).unwrap();
436            }
437        }
438
439        fn truncate(
440            &self,
441            _session_id: &agent_client_protocol::SessionId,
442            _cx: &mut App,
443        ) -> Option<Rc<dyn AgentSessionTruncate>> {
444            Some(Rc::new(StubAgentSessionEditor))
445        }
446
447        fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
448            self
449        }
450    }
451
452    struct StubAgentSessionEditor;
453
454    impl AgentSessionTruncate for StubAgentSessionEditor {
455        fn run(&self, _: UserMessageId, _: &mut App) -> Task<Result<()>> {
456            Task::ready(Ok(()))
457        }
458    }
459}
460
461#[cfg(feature = "test-support")]
462pub use test_support::*;