connection.rs

  1use crate::AcpThread;
  2use agent_client_protocol::{self as acp};
  3use anyhow::Result;
  4use collections::IndexMap;
  5use gpui::{Entity, SharedString, Task};
  6use language_model::LanguageModelProviderId;
  7use project::Project;
  8use serde::{Deserialize, Serialize};
  9use std::{any::Any, error::Error, fmt, path::Path, rc::Rc, sync::Arc};
 10use ui::{App, IconName};
 11use uuid::Uuid;
 12
 13#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
 14pub struct UserMessageId(Arc<str>);
 15
 16impl UserMessageId {
 17    pub fn new() -> Self {
 18        Self(Uuid::new_v4().to_string().into())
 19    }
 20}
 21
 22pub trait AgentConnection {
 23    fn new_thread(
 24        self: Rc<Self>,
 25        project: Entity<Project>,
 26        cwd: &Path,
 27        cx: &mut App,
 28    ) -> Task<Result<Entity<AcpThread>>>;
 29
 30    fn auth_methods(&self) -> &[acp::AuthMethod];
 31
 32    fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;
 33
 34    fn prompt(
 35        &self,
 36        user_message_id: Option<UserMessageId>,
 37        params: acp::PromptRequest,
 38        cx: &mut App,
 39    ) -> Task<Result<acp::PromptResponse>>;
 40
 41    fn resume(
 42        &self,
 43        _session_id: &acp::SessionId,
 44        _cx: &mut App,
 45    ) -> Option<Rc<dyn AgentSessionResume>> {
 46        None
 47    }
 48
 49    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
 50
 51    fn session_editor(
 52        &self,
 53        _session_id: &acp::SessionId,
 54        _cx: &mut App,
 55    ) -> Option<Rc<dyn AgentSessionEditor>> {
 56        None
 57    }
 58
 59    /// Returns this agent as an [Rc<dyn ModelSelector>] if the model selection capability is supported.
 60    ///
 61    /// If the agent does not support model selection, returns [None].
 62    /// This allows sharing the selector in UI components.
 63    fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
 64        None
 65    }
 66
 67    fn telemetry(&self) -> Option<Rc<dyn AgentTelemetry>> {
 68        None
 69    }
 70
 71    fn into_any(self: Rc<Self>) -> Rc<dyn Any>;
 72}
 73
 74impl dyn AgentConnection {
 75    pub fn downcast<T: 'static + AgentConnection + Sized>(self: Rc<Self>) -> Option<Rc<T>> {
 76        self.into_any().downcast().ok()
 77    }
 78}
 79
 80pub trait AgentSessionEditor {
 81    fn truncate(&self, message_id: UserMessageId, cx: &mut App) -> Task<Result<()>>;
 82}
 83
 84pub trait AgentSessionResume {
 85    fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>>;
 86}
 87
 88pub trait AgentTelemetry {
 89    /// The name of the agent used for telemetry.
 90    fn agent_name(&self) -> String;
 91
 92    /// A representation of the current thread state that can be serialized for
 93    /// storage with telemetry events.
 94    fn thread_data(
 95        &self,
 96        session_id: &acp::SessionId,
 97        cx: &mut App,
 98    ) -> Task<Result<serde_json::Value>>;
 99}
100
101#[derive(Debug)]
102pub struct AuthRequired {
103    pub description: Option<String>,
104    pub provider_id: Option<LanguageModelProviderId>,
105}
106
107impl AuthRequired {
108    pub fn new() -> Self {
109        Self {
110            description: None,
111            provider_id: None,
112        }
113    }
114
115    pub fn with_description(mut self, description: String) -> Self {
116        self.description = Some(description);
117        self
118    }
119
120    pub fn with_language_model_provider(mut self, provider_id: LanguageModelProviderId) -> Self {
121        self.provider_id = Some(provider_id);
122        self
123    }
124}
125
126impl Error for AuthRequired {}
127impl fmt::Display for AuthRequired {
128    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
129        write!(f, "Authentication required")
130    }
131}
132
133/// Trait for agents that support listing, selecting, and querying language models.
134///
135/// This is an optional capability; agents indicate support via [AgentConnection::model_selector].
136pub trait AgentModelSelector: 'static {
137    /// Lists all available language models for this agent.
138    ///
139    /// # Parameters
140    /// - `cx`: The GPUI app context for async operations and global access.
141    ///
142    /// # Returns
143    /// A task resolving to the list of models or an error (e.g., if no models are configured).
144    fn list_models(&self, cx: &mut App) -> Task<Result<AgentModelList>>;
145
146    /// Selects a model for a specific session (thread).
147    ///
148    /// This sets the default model for future interactions in the session.
149    /// If the session doesn't exist or the model is invalid, it returns an error.
150    ///
151    /// # Parameters
152    /// - `session_id`: The ID of the session (thread) to apply the model to.
153    /// - `model`: The model to select (should be one from [list_models]).
154    /// - `cx`: The GPUI app context.
155    ///
156    /// # Returns
157    /// A task resolving to `Ok(())` on success or an error.
158    fn select_model(
159        &self,
160        session_id: acp::SessionId,
161        model_id: AgentModelId,
162        cx: &mut App,
163    ) -> Task<Result<()>>;
164
165    /// Retrieves the currently selected model for a specific session (thread).
166    ///
167    /// # Parameters
168    /// - `session_id`: The ID of the session (thread) to query.
169    /// - `cx`: The GPUI app context.
170    ///
171    /// # Returns
172    /// A task resolving to the selected model (always set) or an error (e.g., session not found).
173    fn selected_model(
174        &self,
175        session_id: &acp::SessionId,
176        cx: &mut App,
177    ) -> Task<Result<AgentModelInfo>>;
178
179    /// Whenever the model list is updated the receiver will be notified.
180    fn watch(&self, cx: &mut App) -> watch::Receiver<()>;
181}
182
183#[derive(Debug, Clone, PartialEq, Eq, Hash)]
184pub struct AgentModelId(pub SharedString);
185
186impl std::ops::Deref for AgentModelId {
187    type Target = SharedString;
188
189    fn deref(&self) -> &Self::Target {
190        &self.0
191    }
192}
193
194impl fmt::Display for AgentModelId {
195    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
196        self.0.fmt(f)
197    }
198}
199
200#[derive(Debug, Clone, PartialEq, Eq)]
201pub struct AgentModelInfo {
202    pub id: AgentModelId,
203    pub name: SharedString,
204    pub icon: Option<IconName>,
205}
206
207#[derive(Debug, Clone, PartialEq, Eq, Hash)]
208pub struct AgentModelGroupName(pub SharedString);
209
210#[derive(Debug, Clone)]
211pub enum AgentModelList {
212    Flat(Vec<AgentModelInfo>),
213    Grouped(IndexMap<AgentModelGroupName, Vec<AgentModelInfo>>),
214}
215
216impl AgentModelList {
217    pub fn is_empty(&self) -> bool {
218        match self {
219            AgentModelList::Flat(models) => models.is_empty(),
220            AgentModelList::Grouped(groups) => groups.is_empty(),
221        }
222    }
223}
224
225#[cfg(feature = "test-support")]
226mod test_support {
227    use std::sync::Arc;
228
229    use action_log::ActionLog;
230    use collections::HashMap;
231    use futures::{channel::oneshot, future::try_join_all};
232    use gpui::{AppContext as _, WeakEntity};
233    use parking_lot::Mutex;
234
235    use super::*;
236
237    #[derive(Clone, Default)]
238    pub struct StubAgentConnection {
239        sessions: Arc<Mutex<HashMap<acp::SessionId, Session>>>,
240        permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
241        next_prompt_updates: Arc<Mutex<Vec<acp::SessionUpdate>>>,
242    }
243
244    struct Session {
245        thread: WeakEntity<AcpThread>,
246        response_tx: Option<oneshot::Sender<acp::StopReason>>,
247    }
248
249    impl StubAgentConnection {
250        pub fn new() -> Self {
251            Self {
252                next_prompt_updates: Default::default(),
253                permission_requests: HashMap::default(),
254                sessions: Arc::default(),
255            }
256        }
257
258        pub fn set_next_prompt_updates(&self, updates: Vec<acp::SessionUpdate>) {
259            *self.next_prompt_updates.lock() = updates;
260        }
261
262        pub fn with_permission_requests(
263            mut self,
264            permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
265        ) -> Self {
266            self.permission_requests = permission_requests;
267            self
268        }
269
270        pub fn send_update(
271            &self,
272            session_id: acp::SessionId,
273            update: acp::SessionUpdate,
274            cx: &mut App,
275        ) {
276            assert!(
277                self.next_prompt_updates.lock().is_empty(),
278                "Use either send_update or set_next_prompt_updates"
279            );
280
281            self.sessions
282                .lock()
283                .get(&session_id)
284                .unwrap()
285                .thread
286                .update(cx, |thread, cx| {
287                    thread.handle_session_update(update, cx).unwrap();
288                })
289                .unwrap();
290        }
291
292        pub fn end_turn(&self, session_id: acp::SessionId, stop_reason: acp::StopReason) {
293            self.sessions
294                .lock()
295                .get_mut(&session_id)
296                .unwrap()
297                .response_tx
298                .take()
299                .expect("No pending turn")
300                .send(stop_reason)
301                .unwrap();
302        }
303    }
304
305    impl AgentConnection for StubAgentConnection {
306        fn auth_methods(&self) -> &[acp::AuthMethod] {
307            &[]
308        }
309
310        fn new_thread(
311            self: Rc<Self>,
312            project: Entity<Project>,
313            _cwd: &Path,
314            cx: &mut gpui::App,
315        ) -> Task<gpui::Result<Entity<AcpThread>>> {
316            let session_id = acp::SessionId(self.sessions.lock().len().to_string().into());
317            let action_log = cx.new(|_| ActionLog::new(project.clone()));
318            let thread = cx.new(|_cx| {
319                AcpThread::new(
320                    "Test",
321                    self.clone(),
322                    project,
323                    action_log,
324                    session_id.clone(),
325                )
326            });
327            self.sessions.lock().insert(
328                session_id,
329                Session {
330                    thread: thread.downgrade(),
331                    response_tx: None,
332                },
333            );
334            Task::ready(Ok(thread))
335        }
336
337        fn authenticate(
338            &self,
339            _method_id: acp::AuthMethodId,
340            _cx: &mut App,
341        ) -> Task<gpui::Result<()>> {
342            unimplemented!()
343        }
344
345        fn prompt(
346            &self,
347            _id: Option<UserMessageId>,
348            params: acp::PromptRequest,
349            cx: &mut App,
350        ) -> Task<gpui::Result<acp::PromptResponse>> {
351            let mut sessions = self.sessions.lock();
352            let Session {
353                thread,
354                response_tx,
355            } = sessions.get_mut(&params.session_id).unwrap();
356            let mut tasks = vec![];
357            if self.next_prompt_updates.lock().is_empty() {
358                let (tx, rx) = oneshot::channel();
359                response_tx.replace(tx);
360                cx.spawn(async move |_| {
361                    let stop_reason = rx.await?;
362                    Ok(acp::PromptResponse { stop_reason })
363                })
364            } else {
365                for update in self.next_prompt_updates.lock().drain(..) {
366                    let thread = thread.clone();
367                    let update = update.clone();
368                    let permission_request = if let acp::SessionUpdate::ToolCall(tool_call) =
369                        &update
370                        && let Some(options) = self.permission_requests.get(&tool_call.id)
371                    {
372                        Some((tool_call.clone(), options.clone()))
373                    } else {
374                        None
375                    };
376                    let task = cx.spawn(async move |cx| {
377                        if let Some((tool_call, options)) = permission_request {
378                            let permission = thread.update(cx, |thread, cx| {
379                                thread.request_tool_call_authorization(
380                                    tool_call.clone().into(),
381                                    options.clone(),
382                                    cx,
383                                )
384                            })?;
385                            permission?.await?;
386                        }
387                        thread.update(cx, |thread, cx| {
388                            thread.handle_session_update(update.clone(), cx).unwrap();
389                        })?;
390                        anyhow::Ok(())
391                    });
392                    tasks.push(task);
393                }
394
395                cx.spawn(async move |_| {
396                    try_join_all(tasks).await?;
397                    Ok(acp::PromptResponse {
398                        stop_reason: acp::StopReason::EndTurn,
399                    })
400                })
401            }
402        }
403
404        fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
405            if let Some(end_turn_tx) = self
406                .sessions
407                .lock()
408                .get_mut(session_id)
409                .unwrap()
410                .response_tx
411                .take()
412            {
413                end_turn_tx.send(acp::StopReason::Canceled).unwrap();
414            }
415        }
416
417        fn session_editor(
418            &self,
419            _session_id: &agent_client_protocol::SessionId,
420            _cx: &mut App,
421        ) -> Option<Rc<dyn AgentSessionEditor>> {
422            Some(Rc::new(StubAgentSessionEditor))
423        }
424
425        fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
426            self
427        }
428    }
429
430    struct StubAgentSessionEditor;
431
432    impl AgentSessionEditor for StubAgentSessionEditor {
433        fn truncate(&self, _: UserMessageId, _: &mut App) -> Task<Result<()>> {
434            Task::ready(Ok(()))
435        }
436    }
437}
438
439#[cfg(feature = "test-support")]
440pub use test_support::*;