1use std::rc::Rc;
2
3use acp_thread::{AgentConnection, LoadError};
4use agent_client_protocol as acp;
5use agent_servers::{AgentServer, AgentServerDelegate};
6use anyhow::Result;
7use collections::HashMap;
8use futures::{FutureExt, future::Shared};
9use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Subscription, Task};
10
11use project::{AgentServerStore, AgentServersUpdated, Project};
12use watch::Receiver;
13
14use crate::{Agent, ThreadHistory};
15
16pub enum AgentConnectionEntry {
17 Connecting {
18 connect_task: Shared<Task<Result<AgentConnectedState, LoadError>>>,
19 },
20 Connected(AgentConnectedState),
21 Error {
22 error: LoadError,
23 },
24}
25
26#[derive(Clone)]
27pub struct AgentConnectedState {
28 pub connection: Rc<dyn AgentConnection>,
29 pub history: Option<Entity<ThreadHistory>>,
30}
31
/// Coarse connection status reported for an agent.
///
/// Failed connections and agents without a tracked entry both report `Disconnected`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AgentConnectionStatus {
    /// No usable connection: never requested, no longer tracked, or failed.
    Disconnected,
    /// A connection attempt is in progress.
    Connecting,
    /// A connection has been established.
    Connected,
}
38
39impl AgentConnectionEntry {
40 pub fn wait_for_connection(&self) -> Shared<Task<Result<AgentConnectedState, LoadError>>> {
41 match self {
42 AgentConnectionEntry::Connecting { connect_task } => connect_task.clone(),
43 AgentConnectionEntry::Connected(state) => Task::ready(Ok(state.clone())).shared(),
44 AgentConnectionEntry::Error { error } => Task::ready(Err(error.clone())).shared(),
45 }
46 }
47
48 pub fn history(&self) -> Option<&Entity<ThreadHistory>> {
49 match self {
50 AgentConnectionEntry::Connected(state) => state.history.as_ref(),
51 _ => None,
52 }
53 }
54
55 pub fn status(&self) -> AgentConnectionStatus {
56 match self {
57 AgentConnectionEntry::Connecting { .. } => AgentConnectionStatus::Connecting,
58 AgentConnectionEntry::Connected(_) => AgentConnectionStatus::Connected,
59 AgentConnectionEntry::Error { .. } => AgentConnectionStatus::Disconnected,
60 }
61 }
62}
63
64pub enum AgentConnectionEntryEvent {
65 NewVersionAvailable(SharedString),
66}
67
68impl EventEmitter<AgentConnectionEntryEvent> for AgentConnectionEntry {}
69
70#[derive(Clone)]
71pub struct ActiveAcpConnection {
72 pub agent_id: project::AgentId,
73 pub connection: Rc<acp::ClientSideConnection>,
74}
75
76pub struct AgentConnectionStore {
77 project: Entity<Project>,
78 entries: HashMap<Agent, Entity<AgentConnectionEntry>>,
79 _subscriptions: Vec<Subscription>,
80}
81
82impl AgentConnectionStore {
83 pub fn new(project: Entity<Project>, cx: &mut Context<Self>) -> Self {
84 let agent_server_store = project.read(cx).agent_server_store().clone();
85 let subscription = cx.subscribe(&agent_server_store, Self::handle_agent_servers_updated);
86 Self {
87 project,
88 entries: HashMap::default(),
89 _subscriptions: vec![subscription],
90 }
91 }
92
93 pub fn project(&self) -> &Entity<Project> {
94 &self.project
95 }
96
97 pub fn entry(&self, key: &Agent) -> Option<&Entity<AgentConnectionEntry>> {
98 self.entries.get(key)
99 }
100
101 pub fn connection_status(&self, key: &Agent, cx: &App) -> AgentConnectionStatus {
102 self.entries
103 .get(key)
104 .map(|entry| entry.read(cx).status())
105 .unwrap_or(AgentConnectionStatus::Disconnected)
106 }
107
108 pub fn active_acp_connections(&self, cx: &App) -> Vec<ActiveAcpConnection> {
109 self.entries
110 .values()
111 .filter_map(|entry| match entry.read(cx) {
112 AgentConnectionEntry::Connected(state) => state
113 .connection
114 .clone()
115 .downcast::<agent_servers::AcpConnection>()
116 .map(|connection| ActiveAcpConnection {
117 agent_id: state.connection.agent_id(),
118 connection: connection.client_connection().clone(),
119 }),
120 AgentConnectionEntry::Connecting { .. } | AgentConnectionEntry::Error { .. } => {
121 None
122 }
123 })
124 .collect()
125 }
126
127 pub fn restart_connection(
128 &mut self,
129 key: Agent,
130 server: Rc<dyn AgentServer>,
131 cx: &mut Context<Self>,
132 ) -> Entity<AgentConnectionEntry> {
133 if let Some(entry) = self.entries.get(&key) {
134 if matches!(entry.read(cx), AgentConnectionEntry::Connecting { .. }) {
135 return entry.clone();
136 }
137 }
138
139 self.entries.remove(&key);
140 self.request_connection(key, server, cx)
141 }
142
143 pub fn request_connection(
144 &mut self,
145 key: Agent,
146 server: Rc<dyn AgentServer>,
147 cx: &mut Context<Self>,
148 ) -> Entity<AgentConnectionEntry> {
149 if let Some(entry) = self.entries.get(&key) {
150 return entry.clone();
151 }
152
153 let (mut new_version_rx, connect_task) = self.start_connection(server, cx);
154 let connect_task = connect_task.shared();
155
156 let entry = cx.new(|_cx| AgentConnectionEntry::Connecting {
157 connect_task: connect_task.clone(),
158 });
159
160 self.entries.insert(key.clone(), entry.clone());
161 cx.notify();
162
163 cx.spawn({
164 let key = key.clone();
165 let entry = entry.downgrade();
166 async move |this, cx| match connect_task.await {
167 Ok(connected_state) => {
168 this.update(cx, move |this, cx| {
169 if this.entries.get(&key) != entry.upgrade().as_ref() {
170 return;
171 }
172
173 entry
174 .update(cx, move |entry, cx| {
175 if let AgentConnectionEntry::Connecting { .. } = entry {
176 *entry = AgentConnectionEntry::Connected(connected_state);
177 cx.notify();
178 }
179 })
180 .ok();
181 cx.notify();
182 })
183 .ok();
184 }
185 Err(error) => {
186 this.update(cx, move |this, cx| {
187 if this.entries.get(&key) != entry.upgrade().as_ref() {
188 return;
189 }
190
191 entry
192 .update(cx, move |entry, cx| {
193 if let AgentConnectionEntry::Connecting { .. } = entry {
194 *entry = AgentConnectionEntry::Error { error };
195 cx.notify();
196 }
197 })
198 .ok();
199 this.entries.remove(&key);
200 cx.notify();
201 })
202 .ok();
203 }
204 }
205 })
206 .detach();
207
208 cx.spawn({
209 let entry = entry.downgrade();
210 async move |this, cx| {
211 while let Ok(version) = new_version_rx.recv().await {
212 let Some(version) = version else {
213 continue;
214 };
215
216 this.update(cx, move |this, cx| {
217 if this.entries.get(&key) != entry.upgrade().as_ref() {
218 return;
219 }
220
221 entry
222 .update(cx, move |_entry, cx| {
223 cx.emit(AgentConnectionEntryEvent::NewVersionAvailable(
224 version.into(),
225 ));
226 })
227 .ok();
228 this.entries.remove(&key);
229 cx.notify();
230 })
231 .ok();
232 break;
233 }
234 }
235 })
236 .detach();
237
238 entry
239 }
240
241 fn handle_agent_servers_updated(
242 &mut self,
243 store: Entity<AgentServerStore>,
244 _: &AgentServersUpdated,
245 cx: &mut Context<Self>,
246 ) {
247 let external_agent_ids: Vec<_> = {
248 let store = store.read(cx);
249 store.external_agents.keys().cloned().collect()
250 };
251
252 self.entries.retain(|key, _| match key {
253 Agent::NativeAgent => true,
254 Agent::Custom { id } => external_agent_ids.contains(id),
255 });
256 cx.notify();
257 }
258
259 fn start_connection(
260 &self,
261 server: Rc<dyn AgentServer>,
262 cx: &mut Context<Self>,
263 ) -> (
264 Receiver<Option<String>>,
265 Task<Result<AgentConnectedState, LoadError>>,
266 ) {
267 let (new_version_tx, new_version_rx) = watch::channel::<Option<String>>(None);
268
269 let agent_server_store = self.project.read(cx).agent_server_store().clone();
270 let delegate = AgentServerDelegate::new(agent_server_store, Some(new_version_tx));
271
272 let connect_task = server.connect(delegate, self.project.clone(), cx);
273 let connect_task = cx.spawn(async move |_this, cx| match connect_task.await {
274 Ok(connection) => cx.update(|cx| {
275 let history = connection
276 .session_list(cx)
277 .map(|session_list| cx.new(|cx| ThreadHistory::new(session_list, cx)));
278 Ok(AgentConnectedState {
279 connection,
280 history,
281 })
282 }),
283 Err(err) => match err.downcast::<LoadError>() {
284 Ok(load_error) => Err(load_error),
285 Err(err) => Err(LoadError::Other(SharedString::from(err.to_string()))),
286 },
287 });
288 (new_version_rx, connect_task)
289 }
290}