connection.rs

  1use std::{error::Error, fmt, path::Path, rc::Rc};
  2
  3use agent_client_protocol::{self as acp};
  4use anyhow::Result;
  5use collections::IndexMap;
  6use gpui::{AsyncApp, Entity, SharedString, Task};
  7use project::Project;
  8use ui::{App, IconName};
  9
 10use crate::AcpThread;
 11
 12pub trait AgentConnection {
 13    fn new_thread(
 14        self: Rc<Self>,
 15        project: Entity<Project>,
 16        cwd: &Path,
 17        cx: &mut AsyncApp,
 18    ) -> Task<Result<Entity<AcpThread>>>;
 19
 20    fn auth_methods(&self) -> &[acp::AuthMethod];
 21
 22    fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;
 23
 24    fn prompt(&self, params: acp::PromptRequest, cx: &mut App)
 25    -> Task<Result<acp::PromptResponse>>;
 26
 27    fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
 28
 29    /// Returns this agent as an [Rc<dyn ModelSelector>] if the model selection capability is supported.
 30    ///
 31    /// If the agent does not support model selection, returns [None].
 32    /// This allows sharing the selector in UI components.
 33    fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
 34        None
 35    }
 36}
 37
 38#[derive(Debug)]
 39pub struct AuthRequired;
 40
 41impl Error for AuthRequired {}
 42impl fmt::Display for AuthRequired {
 43    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 44        write!(f, "AuthRequired")
 45    }
 46}
 47
 48/// Trait for agents that support listing, selecting, and querying language models.
 49///
 50/// This is an optional capability; agents indicate support via [AgentConnection::model_selector].
 51pub trait AgentModelSelector: 'static {
 52    /// Lists all available language models for this agent.
 53    ///
 54    /// # Parameters
 55    /// - `cx`: The GPUI app context for async operations and global access.
 56    ///
 57    /// # Returns
 58    /// A task resolving to the list of models or an error (e.g., if no models are configured).
 59    fn list_models(&self, cx: &mut App) -> Task<Result<AgentModelList>>;
 60
 61    /// Selects a model for a specific session (thread).
 62    ///
 63    /// This sets the default model for future interactions in the session.
 64    /// If the session doesn't exist or the model is invalid, it returns an error.
 65    ///
 66    /// # Parameters
 67    /// - `session_id`: The ID of the session (thread) to apply the model to.
 68    /// - `model`: The model to select (should be one from [list_models]).
 69    /// - `cx`: The GPUI app context.
 70    ///
 71    /// # Returns
 72    /// A task resolving to `Ok(())` on success or an error.
 73    fn select_model(
 74        &self,
 75        session_id: acp::SessionId,
 76        model_id: AgentModelId,
 77        cx: &mut App,
 78    ) -> Task<Result<()>>;
 79
 80    /// Retrieves the currently selected model for a specific session (thread).
 81    ///
 82    /// # Parameters
 83    /// - `session_id`: The ID of the session (thread) to query.
 84    /// - `cx`: The GPUI app context.
 85    ///
 86    /// # Returns
 87    /// A task resolving to the selected model (always set) or an error (e.g., session not found).
 88    fn selected_model(
 89        &self,
 90        session_id: &acp::SessionId,
 91        cx: &mut App,
 92    ) -> Task<Result<AgentModelInfo>>;
 93
 94    /// Whenever the model list is updated the receiver will be notified.
 95    fn watch(&self, cx: &mut App) -> watch::Receiver<()>;
 96}
 97
 98#[derive(Debug, Clone, PartialEq, Eq, Hash)]
 99pub struct AgentModelId(pub SharedString);
100
101impl std::ops::Deref for AgentModelId {
102    type Target = SharedString;
103
104    fn deref(&self) -> &Self::Target {
105        &self.0
106    }
107}
108
109impl fmt::Display for AgentModelId {
110    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
111        self.0.fmt(f)
112    }
113}
114
115#[derive(Debug, Clone, PartialEq, Eq)]
116pub struct AgentModelInfo {
117    pub id: AgentModelId,
118    pub name: SharedString,
119    pub icon: Option<IconName>,
120}
121
122#[derive(Debug, Clone, PartialEq, Eq, Hash)]
123pub struct AgentModelGroupName(pub SharedString);
124
125#[derive(Debug, Clone)]
126pub enum AgentModelList {
127    Flat(Vec<AgentModelInfo>),
128    Grouped(IndexMap<AgentModelGroupName, Vec<AgentModelInfo>>),
129}
130
131impl AgentModelList {
132    pub fn is_empty(&self) -> bool {
133        match self {
134            AgentModelList::Flat(models) => models.is_empty(),
135            AgentModelList::Grouped(groups) => groups.is_empty(),
136        }
137    }
138}