1use std::{error::Error, fmt, path::Path, rc::Rc};
2
3use agent_client_protocol::{self as acp};
4use anyhow::Result;
5use collections::IndexMap;
6use gpui::{AsyncApp, Entity, SharedString, Task};
7use project::Project;
8use ui::{App, IconName};
9
10use crate::AcpThread;
11
12pub trait AgentConnection {
13 fn new_thread(
14 self: Rc<Self>,
15 project: Entity<Project>,
16 cwd: &Path,
17 cx: &mut AsyncApp,
18 ) -> Task<Result<Entity<AcpThread>>>;
19
20 fn auth_methods(&self) -> &[acp::AuthMethod];
21
22 fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;
23
24 fn prompt(&self, params: acp::PromptRequest, cx: &mut App)
25 -> Task<Result<acp::PromptResponse>>;
26
27 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
28
29 /// Returns this agent as an [Rc<dyn ModelSelector>] if the model selection capability is supported.
30 ///
31 /// If the agent does not support model selection, returns [None].
32 /// This allows sharing the selector in UI components.
33 fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
34 None
35 }
36}
37
/// Marker error type for the "authentication required" condition.
///
/// It carries no data; both its `Debug` and `Display` forms are the literal
/// text `AuthRequired`.
#[derive(Debug)]
pub struct AuthRequired;

impl fmt::Display for AuthRequired {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Fixed text; no formatting arguments are needed.
        f.write_str("AuthRequired")
    }
}

// Uses the default `Error` methods (no `source`).
impl Error for AuthRequired {}
47
48/// Trait for agents that support listing, selecting, and querying language models.
49///
50/// This is an optional capability; agents indicate support via [AgentConnection::model_selector].
51pub trait AgentModelSelector: 'static {
52 /// Lists all available language models for this agent.
53 ///
54 /// # Parameters
55 /// - `cx`: The GPUI app context for async operations and global access.
56 ///
57 /// # Returns
58 /// A task resolving to the list of models or an error (e.g., if no models are configured).
59 fn list_models(&self, cx: &mut App) -> Task<Result<AgentModelList>>;
60
61 /// Selects a model for a specific session (thread).
62 ///
63 /// This sets the default model for future interactions in the session.
64 /// If the session doesn't exist or the model is invalid, it returns an error.
65 ///
66 /// # Parameters
67 /// - `session_id`: The ID of the session (thread) to apply the model to.
68 /// - `model`: The model to select (should be one from [list_models]).
69 /// - `cx`: The GPUI app context.
70 ///
71 /// # Returns
72 /// A task resolving to `Ok(())` on success or an error.
73 fn select_model(
74 &self,
75 session_id: acp::SessionId,
76 model_id: AgentModelId,
77 cx: &mut App,
78 ) -> Task<Result<()>>;
79
80 /// Retrieves the currently selected model for a specific session (thread).
81 ///
82 /// # Parameters
83 /// - `session_id`: The ID of the session (thread) to query.
84 /// - `cx`: The GPUI app context.
85 ///
86 /// # Returns
87 /// A task resolving to the selected model (always set) or an error (e.g., session not found).
88 fn selected_model(
89 &self,
90 session_id: &acp::SessionId,
91 cx: &mut App,
92 ) -> Task<Result<AgentModelInfo>>;
93
94 /// Whenever the model list is updated the receiver will be notified.
95 fn watch(&self, cx: &mut App) -> watch::Receiver<()>;
96}
97
98#[derive(Debug, Clone, PartialEq, Eq, Hash)]
99pub struct AgentModelId(pub SharedString);
100
101impl std::ops::Deref for AgentModelId {
102 type Target = SharedString;
103
104 fn deref(&self) -> &Self::Target {
105 &self.0
106 }
107}
108
109impl fmt::Display for AgentModelId {
110 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
111 self.0.fmt(f)
112 }
113}
114
115#[derive(Debug, Clone, PartialEq, Eq)]
116pub struct AgentModelInfo {
117 pub id: AgentModelId,
118 pub name: SharedString,
119 pub icon: Option<IconName>,
120}
121
122#[derive(Debug, Clone, PartialEq, Eq, Hash)]
123pub struct AgentModelGroupName(pub SharedString);
124
125#[derive(Debug, Clone)]
126pub enum AgentModelList {
127 Flat(Vec<AgentModelInfo>),
128 Grouped(IndexMap<AgentModelGroupName, Vec<AgentModelInfo>>),
129}
130
131impl AgentModelList {
132 pub fn is_empty(&self) -> bool {
133 match self {
134 AgentModelList::Flat(models) => models.is_empty(),
135 AgentModelList::Grouped(groups) => groups.is_empty(),
136 }
137 }
138}