connection.rs

  1use crate::AcpThread;
  2use agent_client_protocol::{self as acp};
  3use anyhow::Result;
  4use collections::IndexMap;
  5use gpui::{Entity, SharedString, Task};
  6use language_model::LanguageModelProviderId;
  7use project::Project;
  8use serde::{Deserialize, Serialize};
  9use std::{any::Any, error::Error, fmt, path::Path, rc::Rc, sync::Arc};
 10use ui::{App, IconName};
 11use uuid::Uuid;
 12
 13#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
 14pub struct UserMessageId(Arc<str>);
 15
 16impl UserMessageId {
 17    pub fn new() -> Self {
 18        Self(Uuid::new_v4().to_string().into())
 19    }
 20}
 21
 22pub trait AgentConnection {
 23    fn new_thread(
 24        self: Rc<Self>,
 25        project: Entity<Project>,
 26        cwd: &Path,
 27        cx: &mut App,
 28    ) -> Task<Result<Entity<AcpThread>>>;
 29
 30    fn auth_methods(&self) -> &[acp::AuthMethod];
 31
 32    fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;
 33
 34    fn prompt(
 35        &self,
 36        user_message_id: Option<UserMessageId>,
 37        params: acp::PromptRequest,
 38        cx: &mut App,
 39    ) -> Task<Result<acp::PromptResponse>>;
 40
 41    fn prompt_capabilities(&self) -> acp::PromptCapabilities;
 42
 43    fn resume(
 44        &self,
 45        _session_id: &acp::SessionId,
 46        _cx: &mut App,
 47    ) -> Option<Rc<dyn AgentSessionResume>> {
 48        None
 49    }
 50
 51    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
 52
 53    fn session_editor(
 54        &self,
 55        _session_id: &acp::SessionId,
 56        _cx: &mut App,
 57    ) -> Option<Rc<dyn AgentSessionEditor>> {
 58        None
 59    }
 60
 61    /// Returns this agent as an [Rc<dyn ModelSelector>] if the model selection capability is supported.
 62    ///
 63    /// If the agent does not support model selection, returns [None].
 64    /// This allows sharing the selector in UI components.
 65    fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
 66        None
 67    }
 68
 69    fn telemetry(&self) -> Option<Rc<dyn AgentTelemetry>> {
 70        None
 71    }
 72
 73    fn into_any(self: Rc<Self>) -> Rc<dyn Any>;
 74}
 75
 76impl dyn AgentConnection {
 77    pub fn downcast<T: 'static + AgentConnection + Sized>(self: Rc<Self>) -> Option<Rc<T>> {
 78        self.into_any().downcast().ok()
 79    }
 80}
 81
 82pub trait AgentSessionEditor {
 83    fn truncate(&self, message_id: UserMessageId, cx: &mut App) -> Task<Result<()>>;
 84}
 85
 86pub trait AgentSessionResume {
 87    fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>>;
 88}
 89
 90pub trait AgentTelemetry {
 91    /// The name of the agent used for telemetry.
 92    fn agent_name(&self) -> String;
 93
 94    /// A representation of the current thread state that can be serialized for
 95    /// storage with telemetry events.
 96    fn thread_data(
 97        &self,
 98        session_id: &acp::SessionId,
 99        cx: &mut App,
100    ) -> Task<Result<serde_json::Value>>;
101}
102
103#[derive(Debug)]
104pub struct AuthRequired {
105    pub description: Option<String>,
106    pub provider_id: Option<LanguageModelProviderId>,
107}
108
109impl AuthRequired {
110    pub fn new() -> Self {
111        Self {
112            description: None,
113            provider_id: None,
114        }
115    }
116
117    pub fn with_description(mut self, description: String) -> Self {
118        self.description = Some(description);
119        self
120    }
121
122    pub fn with_language_model_provider(mut self, provider_id: LanguageModelProviderId) -> Self {
123        self.provider_id = Some(provider_id);
124        self
125    }
126}
127
128impl Error for AuthRequired {}
129impl fmt::Display for AuthRequired {
130    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
131        write!(f, "Authentication required")
132    }
133}
134
135/// Trait for agents that support listing, selecting, and querying language models.
136///
137/// This is an optional capability; agents indicate support via [AgentConnection::model_selector].
138pub trait AgentModelSelector: 'static {
139    /// Lists all available language models for this agent.
140    ///
141    /// # Parameters
142    /// - `cx`: The GPUI app context for async operations and global access.
143    ///
144    /// # Returns
145    /// A task resolving to the list of models or an error (e.g., if no models are configured).
146    fn list_models(&self, cx: &mut App) -> Task<Result<AgentModelList>>;
147
148    /// Selects a model for a specific session (thread).
149    ///
150    /// This sets the default model for future interactions in the session.
151    /// If the session doesn't exist or the model is invalid, it returns an error.
152    ///
153    /// # Parameters
154    /// - `session_id`: The ID of the session (thread) to apply the model to.
155    /// - `model`: The model to select (should be one from [list_models]).
156    /// - `cx`: The GPUI app context.
157    ///
158    /// # Returns
159    /// A task resolving to `Ok(())` on success or an error.
160    fn select_model(
161        &self,
162        session_id: acp::SessionId,
163        model_id: AgentModelId,
164        cx: &mut App,
165    ) -> Task<Result<()>>;
166
167    /// Retrieves the currently selected model for a specific session (thread).
168    ///
169    /// # Parameters
170    /// - `session_id`: The ID of the session (thread) to query.
171    /// - `cx`: The GPUI app context.
172    ///
173    /// # Returns
174    /// A task resolving to the selected model (always set) or an error (e.g., session not found).
175    fn selected_model(
176        &self,
177        session_id: &acp::SessionId,
178        cx: &mut App,
179    ) -> Task<Result<AgentModelInfo>>;
180
181    /// Whenever the model list is updated the receiver will be notified.
182    fn watch(&self, cx: &mut App) -> watch::Receiver<()>;
183}
184
185#[derive(Debug, Clone, PartialEq, Eq, Hash)]
186pub struct AgentModelId(pub SharedString);
187
188impl std::ops::Deref for AgentModelId {
189    type Target = SharedString;
190
191    fn deref(&self) -> &Self::Target {
192        &self.0
193    }
194}
195
196impl fmt::Display for AgentModelId {
197    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
198        self.0.fmt(f)
199    }
200}
201
202#[derive(Debug, Clone, PartialEq, Eq)]
203pub struct AgentModelInfo {
204    pub id: AgentModelId,
205    pub name: SharedString,
206    pub icon: Option<IconName>,
207}
208
209#[derive(Debug, Clone, PartialEq, Eq, Hash)]
210pub struct AgentModelGroupName(pub SharedString);
211
212#[derive(Debug, Clone)]
213pub enum AgentModelList {
214    Flat(Vec<AgentModelInfo>),
215    Grouped(IndexMap<AgentModelGroupName, Vec<AgentModelInfo>>),
216}
217
218impl AgentModelList {
219    pub fn is_empty(&self) -> bool {
220        match self {
221            AgentModelList::Flat(models) => models.is_empty(),
222            AgentModelList::Grouped(groups) => groups.is_empty(),
223        }
224    }
225}
226
227#[cfg(feature = "test-support")]
228mod test_support {
229    use std::sync::Arc;
230
231    use action_log::ActionLog;
232    use collections::HashMap;
233    use futures::{channel::oneshot, future::try_join_all};
234    use gpui::{AppContext as _, WeakEntity};
235    use parking_lot::Mutex;
236
237    use super::*;
238
239    #[derive(Clone, Default)]
240    pub struct StubAgentConnection {
241        sessions: Arc<Mutex<HashMap<acp::SessionId, Session>>>,
242        permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
243        next_prompt_updates: Arc<Mutex<Vec<acp::SessionUpdate>>>,
244    }
245
246    struct Session {
247        thread: WeakEntity<AcpThread>,
248        response_tx: Option<oneshot::Sender<acp::StopReason>>,
249    }
250
251    impl StubAgentConnection {
252        pub fn new() -> Self {
253            Self {
254                next_prompt_updates: Default::default(),
255                permission_requests: HashMap::default(),
256                sessions: Arc::default(),
257            }
258        }
259
260        pub fn set_next_prompt_updates(&self, updates: Vec<acp::SessionUpdate>) {
261            *self.next_prompt_updates.lock() = updates;
262        }
263
264        pub fn with_permission_requests(
265            mut self,
266            permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
267        ) -> Self {
268            self.permission_requests = permission_requests;
269            self
270        }
271
272        pub fn send_update(
273            &self,
274            session_id: acp::SessionId,
275            update: acp::SessionUpdate,
276            cx: &mut App,
277        ) {
278            assert!(
279                self.next_prompt_updates.lock().is_empty(),
280                "Use either send_update or set_next_prompt_updates"
281            );
282
283            self.sessions
284                .lock()
285                .get(&session_id)
286                .unwrap()
287                .thread
288                .update(cx, |thread, cx| {
289                    thread.handle_session_update(update, cx).unwrap();
290                })
291                .unwrap();
292        }
293
294        pub fn end_turn(&self, session_id: acp::SessionId, stop_reason: acp::StopReason) {
295            self.sessions
296                .lock()
297                .get_mut(&session_id)
298                .unwrap()
299                .response_tx
300                .take()
301                .expect("No pending turn")
302                .send(stop_reason)
303                .unwrap();
304        }
305    }
306
307    impl AgentConnection for StubAgentConnection {
308        fn auth_methods(&self) -> &[acp::AuthMethod] {
309            &[]
310        }
311
312        fn new_thread(
313            self: Rc<Self>,
314            project: Entity<Project>,
315            _cwd: &Path,
316            cx: &mut gpui::App,
317        ) -> Task<gpui::Result<Entity<AcpThread>>> {
318            let session_id = acp::SessionId(self.sessions.lock().len().to_string().into());
319            let action_log = cx.new(|_| ActionLog::new(project.clone()));
320            let thread = cx.new(|_cx| {
321                AcpThread::new(
322                    "Test",
323                    self.clone(),
324                    project,
325                    action_log,
326                    session_id.clone(),
327                )
328            });
329            self.sessions.lock().insert(
330                session_id,
331                Session {
332                    thread: thread.downgrade(),
333                    response_tx: None,
334                },
335            );
336            Task::ready(Ok(thread))
337        }
338
339        fn prompt_capabilities(&self) -> acp::PromptCapabilities {
340            acp::PromptCapabilities {
341                image: true,
342                audio: true,
343                embedded_context: true,
344            }
345        }
346
347        fn authenticate(
348            &self,
349            _method_id: acp::AuthMethodId,
350            _cx: &mut App,
351        ) -> Task<gpui::Result<()>> {
352            unimplemented!()
353        }
354
355        fn prompt(
356            &self,
357            _id: Option<UserMessageId>,
358            params: acp::PromptRequest,
359            cx: &mut App,
360        ) -> Task<gpui::Result<acp::PromptResponse>> {
361            let mut sessions = self.sessions.lock();
362            let Session {
363                thread,
364                response_tx,
365            } = sessions.get_mut(&params.session_id).unwrap();
366            let mut tasks = vec![];
367            if self.next_prompt_updates.lock().is_empty() {
368                let (tx, rx) = oneshot::channel();
369                response_tx.replace(tx);
370                cx.spawn(async move |_| {
371                    let stop_reason = rx.await?;
372                    Ok(acp::PromptResponse { stop_reason })
373                })
374            } else {
375                for update in self.next_prompt_updates.lock().drain(..) {
376                    let thread = thread.clone();
377                    let update = update.clone();
378                    let permission_request = if let acp::SessionUpdate::ToolCall(tool_call) =
379                        &update
380                        && let Some(options) = self.permission_requests.get(&tool_call.id)
381                    {
382                        Some((tool_call.clone(), options.clone()))
383                    } else {
384                        None
385                    };
386                    let task = cx.spawn(async move |cx| {
387                        if let Some((tool_call, options)) = permission_request {
388                            let permission = thread.update(cx, |thread, cx| {
389                                thread.request_tool_call_authorization(
390                                    tool_call.clone().into(),
391                                    options.clone(),
392                                    cx,
393                                )
394                            })?;
395                            permission?.await?;
396                        }
397                        thread.update(cx, |thread, cx| {
398                            thread.handle_session_update(update.clone(), cx).unwrap();
399                        })?;
400                        anyhow::Ok(())
401                    });
402                    tasks.push(task);
403                }
404
405                cx.spawn(async move |_| {
406                    try_join_all(tasks).await?;
407                    Ok(acp::PromptResponse {
408                        stop_reason: acp::StopReason::EndTurn,
409                    })
410                })
411            }
412        }
413
414        fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
415            if let Some(end_turn_tx) = self
416                .sessions
417                .lock()
418                .get_mut(session_id)
419                .unwrap()
420                .response_tx
421                .take()
422            {
423                end_turn_tx.send(acp::StopReason::Cancelled).unwrap();
424            }
425        }
426
427        fn session_editor(
428            &self,
429            _session_id: &agent_client_protocol::SessionId,
430            _cx: &mut App,
431        ) -> Option<Rc<dyn AgentSessionEditor>> {
432            Some(Rc::new(StubAgentSessionEditor))
433        }
434
435        fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
436            self
437        }
438    }
439
440    struct StubAgentSessionEditor;
441
442    impl AgentSessionEditor for StubAgentSessionEditor {
443        fn truncate(&self, _: UserMessageId, _: &mut App) -> Task<Result<()>> {
444            Task::ready(Ok(()))
445        }
446    }
447}
448
449#[cfg(feature = "test-support")]
450pub use test_support::*;