1use std::rc::Rc;
2
3use acp_thread::{AgentConnection, LoadError};
4use agent_servers::AcpConnection;
5use agent_servers::{AgentServer, AgentServerDelegate};
6use anyhow::Result;
7use collections::HashMap;
8use futures::{FutureExt, future::Shared};
9use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Subscription, Task};
10
11use project::{AgentServerStore, AgentServersUpdated, Project};
12use watch::Receiver;
13
14use crate::Agent;
15
16pub enum AgentConnectionEntry {
17 Connecting {
18 connect_task: Shared<Task<Result<AgentConnectedState, LoadError>>>,
19 },
20 Connected(AgentConnectedState),
21 Error {
22 error: LoadError,
23 },
24}
25
26#[derive(Clone)]
27pub struct AgentConnectedState {
28 pub connection: Rc<dyn AgentConnection>,
29}
30
/// Coarse connection state reported for an agent.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AgentConnectionStatus {
    /// No usable connection: none was requested, it failed, or it was discarded.
    Disconnected,
    /// A connection attempt is in flight.
    Connecting,
    /// A connection is established.
    Connected,
}
37
38impl AgentConnectionEntry {
39 pub fn wait_for_connection(&self) -> Shared<Task<Result<AgentConnectedState, LoadError>>> {
40 match self {
41 AgentConnectionEntry::Connecting { connect_task } => connect_task.clone(),
42 AgentConnectionEntry::Connected(state) => Task::ready(Ok(state.clone())).shared(),
43 AgentConnectionEntry::Error { error } => Task::ready(Err(error.clone())).shared(),
44 }
45 }
46
47 pub fn status(&self) -> AgentConnectionStatus {
48 match self {
49 AgentConnectionEntry::Connecting { .. } => AgentConnectionStatus::Connecting,
50 AgentConnectionEntry::Connected(_) => AgentConnectionStatus::Connected,
51 AgentConnectionEntry::Error { .. } => AgentConnectionStatus::Disconnected,
52 }
53 }
54}
55
56pub enum AgentConnectionEntryEvent {
57 NewVersionAvailable(SharedString),
58}
59
60impl EventEmitter<AgentConnectionEntryEvent> for AgentConnectionEntry {}
61
62#[derive(Clone)]
63pub struct ActiveAcpConnection {
64 pub agent_id: project::AgentId,
65 pub connection: Rc<AcpConnection>,
66}
67
68pub struct AgentConnectionStore {
69 project: Entity<Project>,
70 entries: HashMap<Agent, Entity<AgentConnectionEntry>>,
71 _subscriptions: Vec<Subscription>,
72}
73
74impl AgentConnectionStore {
75 pub fn new(project: Entity<Project>, cx: &mut Context<Self>) -> Self {
76 let agent_server_store = project.read(cx).agent_server_store().clone();
77 let subscription = cx.subscribe(&agent_server_store, Self::handle_agent_servers_updated);
78 Self {
79 project,
80 entries: HashMap::default(),
81 _subscriptions: vec![subscription],
82 }
83 }
84
85 pub fn project(&self) -> &Entity<Project> {
86 &self.project
87 }
88
89 pub fn entry(&self, key: &Agent) -> Option<&Entity<AgentConnectionEntry>> {
90 self.entries.get(key)
91 }
92
93 pub fn connection_status(&self, key: &Agent, cx: &App) -> AgentConnectionStatus {
94 self.entries
95 .get(key)
96 .map(|entry| entry.read(cx).status())
97 .unwrap_or(AgentConnectionStatus::Disconnected)
98 }
99
100 pub fn active_acp_connections(&self, cx: &App) -> Vec<ActiveAcpConnection> {
101 self.entries
102 .values()
103 .filter_map(|entry| match entry.read(cx) {
104 AgentConnectionEntry::Connected(state) => state
105 .connection
106 .clone()
107 .downcast::<AcpConnection>()
108 .map(|connection| ActiveAcpConnection {
109 agent_id: state.connection.agent_id(),
110 connection,
111 }),
112 AgentConnectionEntry::Connecting { .. } | AgentConnectionEntry::Error { .. } => {
113 None
114 }
115 })
116 .collect()
117 }
118
119 pub fn restart_connection(
120 &mut self,
121 key: Agent,
122 server: Rc<dyn AgentServer>,
123 cx: &mut Context<Self>,
124 ) -> Entity<AgentConnectionEntry> {
125 if let Some(entry) = self.entries.get(&key) {
126 if matches!(entry.read(cx), AgentConnectionEntry::Connecting { .. }) {
127 return entry.clone();
128 }
129 }
130
131 self.entries.remove(&key);
132 self.request_connection(key, server, cx)
133 }
134
135 pub fn request_connection(
136 &mut self,
137 key: Agent,
138 server: Rc<dyn AgentServer>,
139 cx: &mut Context<Self>,
140 ) -> Entity<AgentConnectionEntry> {
141 if let Some(entry) = self.entries.get(&key) {
142 return entry.clone();
143 }
144
145 let (mut new_version_rx, connect_task) = self.start_connection(server, cx);
146 let connect_task = connect_task.shared();
147
148 let entry = cx.new(|_cx| AgentConnectionEntry::Connecting {
149 connect_task: connect_task.clone(),
150 });
151
152 self.entries.insert(key.clone(), entry.clone());
153 cx.notify();
154
155 cx.spawn({
156 let key = key.clone();
157 let entry = entry.downgrade();
158 async move |this, cx| match connect_task.await {
159 Ok(connected_state) => {
160 this.update(cx, move |this, cx| {
161 if this.entries.get(&key) != entry.upgrade().as_ref() {
162 return;
163 }
164
165 entry
166 .update(cx, move |entry, cx| {
167 if let AgentConnectionEntry::Connecting { .. } = entry {
168 *entry = AgentConnectionEntry::Connected(connected_state);
169 cx.notify();
170 }
171 })
172 .ok();
173 cx.notify();
174 })
175 .ok();
176 }
177 Err(error) => {
178 this.update(cx, move |this, cx| {
179 if this.entries.get(&key) != entry.upgrade().as_ref() {
180 return;
181 }
182
183 entry
184 .update(cx, move |entry, cx| {
185 if let AgentConnectionEntry::Connecting { .. } = entry {
186 *entry = AgentConnectionEntry::Error { error };
187 cx.notify();
188 }
189 })
190 .ok();
191 this.entries.remove(&key);
192 cx.notify();
193 })
194 .ok();
195 }
196 }
197 })
198 .detach();
199
200 cx.spawn({
201 let entry = entry.downgrade();
202 async move |this, cx| {
203 while let Ok(version) = new_version_rx.recv().await {
204 let Some(version) = version else {
205 continue;
206 };
207
208 this.update(cx, move |this, cx| {
209 if this.entries.get(&key) != entry.upgrade().as_ref() {
210 return;
211 }
212
213 entry
214 .update(cx, move |_entry, cx| {
215 cx.emit(AgentConnectionEntryEvent::NewVersionAvailable(
216 version.into(),
217 ));
218 })
219 .ok();
220 this.entries.remove(&key);
221 cx.notify();
222 })
223 .ok();
224 break;
225 }
226 }
227 })
228 .detach();
229
230 entry
231 }
232
233 fn handle_agent_servers_updated(
234 &mut self,
235 store: Entity<AgentServerStore>,
236 _: &AgentServersUpdated,
237 cx: &mut Context<Self>,
238 ) {
239 let store = store.read(cx);
240 self.entries.retain(|key, _| match key {
241 Agent::NativeAgent => true,
242 Agent::Custom { id } => store.external_agents.contains_key(id),
243 #[cfg(any(test, feature = "test-support"))]
244 Agent::Stub => true,
245 });
246 cx.notify();
247 }
248
249 fn start_connection(
250 &self,
251 server: Rc<dyn AgentServer>,
252 cx: &mut Context<Self>,
253 ) -> (
254 Receiver<Option<String>>,
255 Task<Result<AgentConnectedState, LoadError>>,
256 ) {
257 let (new_version_tx, new_version_rx) = watch::channel::<Option<String>>(None);
258
259 let agent_server_store = self.project.read(cx).agent_server_store().clone();
260 let delegate = AgentServerDelegate::new(agent_server_store, Some(new_version_tx));
261
262 let connect_task = server.connect(delegate, self.project.clone(), cx);
263 let connect_task = cx.spawn(async move |_this, _cx| match connect_task.await {
264 Ok(connection) => Ok(AgentConnectedState { connection }),
265 Err(err) => match err.downcast::<LoadError>() {
266 Ok(load_error) => Err(load_error),
267 Err(err) => Err(LoadError::Other(SharedString::from(err.to_string()))),
268 },
269 });
270 (new_version_rx, connect_task)
271 }
272}