connection.rs

  1use crate::AcpThread;
  2use agent_client_protocol::{self as acp};
  3use anyhow::Result;
  4use collections::IndexMap;
  5use gpui::{Entity, SharedString, Task};
  6use language_model::LanguageModelProviderId;
  7use project::Project;
  8use serde::{Deserialize, Serialize};
  9use std::{any::Any, error::Error, fmt, path::Path, rc::Rc, sync::Arc};
 10use ui::{App, IconName};
 11use uuid::Uuid;
 12
 13#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
 14pub struct UserMessageId(Arc<str>);
 15
 16impl UserMessageId {
 17    pub fn new() -> Self {
 18        Self(Uuid::new_v4().to_string().into())
 19    }
 20}
 21
 22pub trait AgentConnection {
 23    fn new_thread(
 24        self: Rc<Self>,
 25        project: Entity<Project>,
 26        cwd: &Path,
 27        cx: &mut App,
 28    ) -> Task<Result<Entity<AcpThread>>>;
 29
 30    fn auth_methods(&self) -> &[acp::AuthMethod];
 31
 32    fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;
 33
 34    fn prompt(
 35        &self,
 36        user_message_id: Option<UserMessageId>,
 37        params: acp::PromptRequest,
 38        cx: &mut App,
 39    ) -> Task<Result<acp::PromptResponse>>;
 40
 41    fn resume(
 42        &self,
 43        _session_id: &acp::SessionId,
 44        _cx: &App,
 45    ) -> Option<Rc<dyn AgentSessionResume>> {
 46        None
 47    }
 48
 49    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
 50
 51    fn truncate(
 52        &self,
 53        _session_id: &acp::SessionId,
 54        _cx: &App,
 55    ) -> Option<Rc<dyn AgentSessionTruncate>> {
 56        None
 57    }
 58
 59    fn set_title(
 60        &self,
 61        _session_id: &acp::SessionId,
 62        _cx: &App,
 63    ) -> Option<Rc<dyn AgentSessionSetTitle>> {
 64        None
 65    }
 66
 67    /// Returns this agent as an [Rc<dyn ModelSelector>] if the model selection capability is supported.
 68    ///
 69    /// If the agent does not support model selection, returns [None].
 70    /// This allows sharing the selector in UI components.
 71    fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
 72        None
 73    }
 74
 75    fn telemetry(&self) -> Option<Rc<dyn AgentTelemetry>> {
 76        None
 77    }
 78
 79    fn into_any(self: Rc<Self>) -> Rc<dyn Any>;
 80}
 81
 82impl dyn AgentConnection {
 83    pub fn downcast<T: 'static + AgentConnection + Sized>(self: Rc<Self>) -> Option<Rc<T>> {
 84        self.into_any().downcast().ok()
 85    }
 86}
 87
/// Capability handle for truncating a session's history, obtained from
/// [`AgentConnection::truncate`].
pub trait AgentSessionTruncate {
    /// Truncates the session at the user message identified by `message_id`.
    ///
    /// NOTE(review): whether that message itself is kept or removed is not visible here —
    /// confirm against implementations.
    fn run(&self, message_id: UserMessageId, cx: &mut App) -> Task<Result<()>>;
}
 91
/// Capability handle for resuming a session, obtained from [`AgentConnection::resume`].
pub trait AgentSessionResume {
    /// Resumes the session and resolves with the resulting prompt response.
    fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>>;
}
 95
/// Capability handle for renaming a session, obtained from [`AgentConnection::set_title`].
pub trait AgentSessionSetTitle {
    /// Sets the session's title to `title`.
    fn run(&self, title: SharedString, cx: &mut App) -> Task<Result<()>>;
}
 99
/// Telemetry capability of an agent, obtained from [`AgentConnection::telemetry`].
pub trait AgentTelemetry {
    /// The name of the agent used for telemetry.
    fn agent_name(&self) -> String;

    /// A representation of the current thread state that can be serialized for
    /// storage with telemetry events.
    ///
    /// Resolves to a JSON value describing the thread for `session_id`.
    fn thread_data(
        &self,
        session_id: &acp::SessionId,
        cx: &mut App,
    ) -> Task<Result<serde_json::Value>>;
}
112
113#[derive(Debug)]
114pub struct AuthRequired {
115    pub description: Option<String>,
116    pub provider_id: Option<LanguageModelProviderId>,
117}
118
119impl AuthRequired {
120    pub fn new() -> Self {
121        Self {
122            description: None,
123            provider_id: None,
124        }
125    }
126
127    pub fn with_description(mut self, description: String) -> Self {
128        self.description = Some(description);
129        self
130    }
131
132    pub fn with_language_model_provider(mut self, provider_id: LanguageModelProviderId) -> Self {
133        self.provider_id = Some(provider_id);
134        self
135    }
136}
137
138impl Error for AuthRequired {}
139impl fmt::Display for AuthRequired {
140    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
141        write!(f, "Authentication required")
142    }
143}
144
145/// Trait for agents that support listing, selecting, and querying language models.
146///
147/// This is an optional capability; agents indicate support via [AgentConnection::model_selector].
148pub trait AgentModelSelector: 'static {
149    /// Lists all available language models for this agent.
150    ///
151    /// # Parameters
152    /// - `cx`: The GPUI app context for async operations and global access.
153    ///
154    /// # Returns
155    /// A task resolving to the list of models or an error (e.g., if no models are configured).
156    fn list_models(&self, cx: &mut App) -> Task<Result<AgentModelList>>;
157
158    /// Selects a model for a specific session (thread).
159    ///
160    /// This sets the default model for future interactions in the session.
161    /// If the session doesn't exist or the model is invalid, it returns an error.
162    ///
163    /// # Parameters
164    /// - `session_id`: The ID of the session (thread) to apply the model to.
165    /// - `model`: The model to select (should be one from [list_models]).
166    /// - `cx`: The GPUI app context.
167    ///
168    /// # Returns
169    /// A task resolving to `Ok(())` on success or an error.
170    fn select_model(
171        &self,
172        session_id: acp::SessionId,
173        model_id: AgentModelId,
174        cx: &mut App,
175    ) -> Task<Result<()>>;
176
177    /// Retrieves the currently selected model for a specific session (thread).
178    ///
179    /// # Parameters
180    /// - `session_id`: The ID of the session (thread) to query.
181    /// - `cx`: The GPUI app context.
182    ///
183    /// # Returns
184    /// A task resolving to the selected model (always set) or an error (e.g., session not found).
185    fn selected_model(
186        &self,
187        session_id: &acp::SessionId,
188        cx: &mut App,
189    ) -> Task<Result<AgentModelInfo>>;
190
191    /// Whenever the model list is updated the receiver will be notified.
192    fn watch(&self, cx: &mut App) -> watch::Receiver<()>;
193}
194
195#[derive(Debug, Clone, PartialEq, Eq, Hash)]
196pub struct AgentModelId(pub SharedString);
197
198impl std::ops::Deref for AgentModelId {
199    type Target = SharedString;
200
201    fn deref(&self) -> &Self::Target {
202        &self.0
203    }
204}
205
206impl fmt::Display for AgentModelId {
207    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
208        self.0.fmt(f)
209    }
210}
211
/// Description of a single language model offered by an agent.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AgentModelInfo {
    /// Identifier passed to [`AgentModelSelector::select_model`] to choose this model.
    pub id: AgentModelId,
    /// Name of the model, presumably for display.
    pub name: SharedString,
    /// Optional icon associated with the model.
    pub icon: Option<IconName>,
}
218
/// Name of a group of models, used as the key in [`AgentModelList::Grouped`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct AgentModelGroupName(pub SharedString);
221
/// The models an agent offers, either as a single flat list or partitioned into named groups.
#[derive(Debug, Clone)]
pub enum AgentModelList {
    /// All models in one list.
    Flat(Vec<AgentModelInfo>),
    /// Models keyed by group name; group order is presumably insertion order, as kept by
    /// `IndexMap` — confirm against the `collections` re-export.
    Grouped(IndexMap<AgentModelGroupName, Vec<AgentModelInfo>>),
}
227
228impl AgentModelList {
229    pub fn is_empty(&self) -> bool {
230        match self {
231            AgentModelList::Flat(models) => models.is_empty(),
232            AgentModelList::Grouped(groups) => groups.is_empty(),
233        }
234    }
235}
236
237#[cfg(feature = "test-support")]
238mod test_support {
239    use std::sync::Arc;
240
241    use action_log::ActionLog;
242    use collections::HashMap;
243    use futures::{channel::oneshot, future::try_join_all};
244    use gpui::{AppContext as _, WeakEntity};
245    use parking_lot::Mutex;
246
247    use super::*;
248
249    #[derive(Clone, Default)]
250    pub struct StubAgentConnection {
251        sessions: Arc<Mutex<HashMap<acp::SessionId, Session>>>,
252        permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
253        next_prompt_updates: Arc<Mutex<Vec<acp::SessionUpdate>>>,
254    }
255
256    struct Session {
257        thread: WeakEntity<AcpThread>,
258        response_tx: Option<oneshot::Sender<acp::StopReason>>,
259    }
260
261    impl StubAgentConnection {
262        pub fn new() -> Self {
263            Self {
264                next_prompt_updates: Default::default(),
265                permission_requests: HashMap::default(),
266                sessions: Arc::default(),
267            }
268        }
269
270        pub fn set_next_prompt_updates(&self, updates: Vec<acp::SessionUpdate>) {
271            *self.next_prompt_updates.lock() = updates;
272        }
273
274        pub fn with_permission_requests(
275            mut self,
276            permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
277        ) -> Self {
278            self.permission_requests = permission_requests;
279            self
280        }
281
282        pub fn send_update(
283            &self,
284            session_id: acp::SessionId,
285            update: acp::SessionUpdate,
286            cx: &mut App,
287        ) {
288            assert!(
289                self.next_prompt_updates.lock().is_empty(),
290                "Use either send_update or set_next_prompt_updates"
291            );
292
293            self.sessions
294                .lock()
295                .get(&session_id)
296                .unwrap()
297                .thread
298                .update(cx, |thread, cx| {
299                    thread.handle_session_update(update, cx).unwrap();
300                })
301                .unwrap();
302        }
303
304        pub fn end_turn(&self, session_id: acp::SessionId, stop_reason: acp::StopReason) {
305            self.sessions
306                .lock()
307                .get_mut(&session_id)
308                .unwrap()
309                .response_tx
310                .take()
311                .expect("No pending turn")
312                .send(stop_reason)
313                .unwrap();
314        }
315    }
316
317    impl AgentConnection for StubAgentConnection {
318        fn auth_methods(&self) -> &[acp::AuthMethod] {
319            &[]
320        }
321
322        fn new_thread(
323            self: Rc<Self>,
324            project: Entity<Project>,
325            _cwd: &Path,
326            cx: &mut gpui::App,
327        ) -> Task<gpui::Result<Entity<AcpThread>>> {
328            let session_id = acp::SessionId(self.sessions.lock().len().to_string().into());
329            let action_log = cx.new(|_| ActionLog::new(project.clone()));
330            let thread = cx.new(|cx| {
331                AcpThread::new(
332                    "Test",
333                    self.clone(),
334                    project,
335                    action_log,
336                    session_id.clone(),
337                    watch::Receiver::constant(acp::PromptCapabilities {
338                        image: true,
339                        audio: true,
340                        embedded_context: true,
341                    }),
342                    cx,
343                )
344            });
345            self.sessions.lock().insert(
346                session_id,
347                Session {
348                    thread: thread.downgrade(),
349                    response_tx: None,
350                },
351            );
352            Task::ready(Ok(thread))
353        }
354
355        fn authenticate(
356            &self,
357            _method_id: acp::AuthMethodId,
358            _cx: &mut App,
359        ) -> Task<gpui::Result<()>> {
360            unimplemented!()
361        }
362
363        fn prompt(
364            &self,
365            _id: Option<UserMessageId>,
366            params: acp::PromptRequest,
367            cx: &mut App,
368        ) -> Task<gpui::Result<acp::PromptResponse>> {
369            let mut sessions = self.sessions.lock();
370            let Session {
371                thread,
372                response_tx,
373            } = sessions.get_mut(&params.session_id).unwrap();
374            let mut tasks = vec![];
375            if self.next_prompt_updates.lock().is_empty() {
376                let (tx, rx) = oneshot::channel();
377                response_tx.replace(tx);
378                cx.spawn(async move |_| {
379                    let stop_reason = rx.await?;
380                    Ok(acp::PromptResponse { stop_reason })
381                })
382            } else {
383                for update in self.next_prompt_updates.lock().drain(..) {
384                    let thread = thread.clone();
385                    let update = update.clone();
386                    let permission_request = if let acp::SessionUpdate::ToolCall(tool_call) =
387                        &update
388                        && let Some(options) = self.permission_requests.get(&tool_call.id)
389                    {
390                        Some((tool_call.clone(), options.clone()))
391                    } else {
392                        None
393                    };
394                    let task = cx.spawn(async move |cx| {
395                        if let Some((tool_call, options)) = permission_request {
396                            thread
397                                .update(cx, |thread, cx| {
398                                    thread.request_tool_call_authorization(
399                                        tool_call.clone().into(),
400                                        options.clone(),
401                                        cx,
402                                    )
403                                })??
404                                .await;
405                        }
406                        thread.update(cx, |thread, cx| {
407                            thread.handle_session_update(update.clone(), cx).unwrap();
408                        })?;
409                        anyhow::Ok(())
410                    });
411                    tasks.push(task);
412                }
413
414                cx.spawn(async move |_| {
415                    try_join_all(tasks).await?;
416                    Ok(acp::PromptResponse {
417                        stop_reason: acp::StopReason::EndTurn,
418                    })
419                })
420            }
421        }
422
423        fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
424            if let Some(end_turn_tx) = self
425                .sessions
426                .lock()
427                .get_mut(session_id)
428                .unwrap()
429                .response_tx
430                .take()
431            {
432                end_turn_tx.send(acp::StopReason::Cancelled).unwrap();
433            }
434        }
435
436        fn truncate(
437            &self,
438            _session_id: &agent_client_protocol::SessionId,
439            _cx: &App,
440        ) -> Option<Rc<dyn AgentSessionTruncate>> {
441            Some(Rc::new(StubAgentSessionEditor))
442        }
443
444        fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
445            self
446        }
447    }
448
449    struct StubAgentSessionEditor;
450
451    impl AgentSessionTruncate for StubAgentSessionEditor {
452        fn run(&self, _: UserMessageId, _: &mut App) -> Task<Result<()>> {
453            Task::ready(Ok(()))
454        }
455    }
456}
457
458#[cfg(feature = "test-support")]
459pub use test_support::*;