1use std::rc::Rc;
2
3use acp_thread::{AgentConnection, LoadError};
4use agent_servers::{AgentServer, AgentServerDelegate};
5use anyhow::Result;
6use collections::HashMap;
7use futures::{FutureExt, future::Shared};
8use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Subscription, Task};
9
10use project::{AgentServerStore, AgentServersUpdated, Project};
11use watch::Receiver;
12
13use crate::Agent;
14
15pub enum AgentConnectionEntry {
16 Connecting {
17 connect_task: Shared<Task<Result<AgentConnectedState, LoadError>>>,
18 },
19 Connected(AgentConnectedState),
20 Error {
21 error: LoadError,
22 },
23}
24
25#[derive(Clone)]
26pub struct AgentConnectedState {
27 pub connection: Rc<dyn AgentConnection>,
28}
29
/// Coarse connection status, as reported by `AgentConnectionEntry::status` and
/// `AgentConnectionStore::connection_status`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum AgentConnectionStatus {
    /// No live connection: the entry failed, or no entry is tracked for the agent.
    Disconnected,
    /// A connection attempt is in flight.
    Connecting,
    /// The agent is connected.
    Connected,
}
36
37impl AgentConnectionEntry {
38 pub fn wait_for_connection(&self) -> Shared<Task<Result<AgentConnectedState, LoadError>>> {
39 match self {
40 AgentConnectionEntry::Connecting { connect_task } => connect_task.clone(),
41 AgentConnectionEntry::Connected(state) => Task::ready(Ok(state.clone())).shared(),
42 AgentConnectionEntry::Error { error } => Task::ready(Err(error.clone())).shared(),
43 }
44 }
45
46 pub fn status(&self) -> AgentConnectionStatus {
47 match self {
48 AgentConnectionEntry::Connecting { .. } => AgentConnectionStatus::Connecting,
49 AgentConnectionEntry::Connected(_) => AgentConnectionStatus::Connected,
50 AgentConnectionEntry::Error { .. } => AgentConnectionStatus::Disconnected,
51 }
52 }
53}
54
55pub enum AgentConnectionEntryEvent {
56 NewVersionAvailable(SharedString),
57}
58
59impl EventEmitter<AgentConnectionEntryEvent> for AgentConnectionEntry {}
60
61pub struct AgentConnectionStore {
62 project: Entity<Project>,
63 entries: HashMap<Agent, Entity<AgentConnectionEntry>>,
64 _subscriptions: Vec<Subscription>,
65}
66
67impl AgentConnectionStore {
68 pub fn new(project: Entity<Project>, cx: &mut Context<Self>) -> Self {
69 let agent_server_store = project.read(cx).agent_server_store().clone();
70 let subscription = cx.subscribe(&agent_server_store, Self::handle_agent_servers_updated);
71 Self {
72 project,
73 entries: HashMap::default(),
74 _subscriptions: vec![subscription],
75 }
76 }
77
78 pub fn project(&self) -> &Entity<Project> {
79 &self.project
80 }
81
82 pub fn entry(&self, key: &Agent) -> Option<&Entity<AgentConnectionEntry>> {
83 self.entries.get(key)
84 }
85
86 pub fn connection_status(&self, key: &Agent, cx: &App) -> AgentConnectionStatus {
87 self.entries
88 .get(key)
89 .map(|entry| entry.read(cx).status())
90 .unwrap_or(AgentConnectionStatus::Disconnected)
91 }
92
93 pub fn restart_connection(
94 &mut self,
95 key: Agent,
96 server: Rc<dyn AgentServer>,
97 cx: &mut Context<Self>,
98 ) -> Entity<AgentConnectionEntry> {
99 if let Some(entry) = self.entries.get(&key) {
100 if matches!(entry.read(cx), AgentConnectionEntry::Connecting { .. }) {
101 return entry.clone();
102 }
103 }
104
105 self.entries.remove(&key);
106 self.request_connection(key, server, cx)
107 }
108
109 pub fn request_connection(
110 &mut self,
111 key: Agent,
112 server: Rc<dyn AgentServer>,
113 cx: &mut Context<Self>,
114 ) -> Entity<AgentConnectionEntry> {
115 if let Some(entry) = self.entries.get(&key) {
116 return entry.clone();
117 }
118
119 let (mut new_version_rx, connect_task) = self.start_connection(server, cx);
120 let connect_task = connect_task.shared();
121
122 let entry = cx.new(|_cx| AgentConnectionEntry::Connecting {
123 connect_task: connect_task.clone(),
124 });
125
126 self.entries.insert(key.clone(), entry.clone());
127 cx.notify();
128
129 cx.spawn({
130 let key = key.clone();
131 let entry = entry.downgrade();
132 async move |this, cx| match connect_task.await {
133 Ok(connected_state) => {
134 this.update(cx, move |this, cx| {
135 if this.entries.get(&key) != entry.upgrade().as_ref() {
136 return;
137 }
138
139 entry
140 .update(cx, move |entry, cx| {
141 if let AgentConnectionEntry::Connecting { .. } = entry {
142 *entry = AgentConnectionEntry::Connected(connected_state);
143 cx.notify();
144 }
145 })
146 .ok();
147 })
148 .ok();
149 }
150 Err(error) => {
151 this.update(cx, move |this, cx| {
152 if this.entries.get(&key) != entry.upgrade().as_ref() {
153 return;
154 }
155
156 entry
157 .update(cx, move |entry, cx| {
158 if let AgentConnectionEntry::Connecting { .. } = entry {
159 *entry = AgentConnectionEntry::Error { error };
160 cx.notify();
161 }
162 })
163 .ok();
164 this.entries.remove(&key);
165 cx.notify();
166 })
167 .ok();
168 }
169 }
170 })
171 .detach();
172
173 cx.spawn({
174 let entry = entry.downgrade();
175 async move |this, cx| {
176 while let Ok(version) = new_version_rx.recv().await {
177 let Some(version) = version else {
178 continue;
179 };
180
181 this.update(cx, move |this, cx| {
182 if this.entries.get(&key) != entry.upgrade().as_ref() {
183 return;
184 }
185
186 entry
187 .update(cx, move |_entry, cx| {
188 cx.emit(AgentConnectionEntryEvent::NewVersionAvailable(
189 version.into(),
190 ));
191 })
192 .ok();
193 this.entries.remove(&key);
194 cx.notify();
195 })
196 .ok();
197 break;
198 }
199 }
200 })
201 .detach();
202
203 entry
204 }
205
206 fn handle_agent_servers_updated(
207 &mut self,
208 store: Entity<AgentServerStore>,
209 _: &AgentServersUpdated,
210 cx: &mut Context<Self>,
211 ) {
212 let store = store.read(cx);
213 self.entries.retain(|key, _| match key {
214 Agent::NativeAgent => true,
215 Agent::Custom { id } => store.external_agents.contains_key(id),
216 #[cfg(any(test, feature = "test-support"))]
217 Agent::Stub => true,
218 });
219 cx.notify();
220 }
221
222 fn start_connection(
223 &self,
224 server: Rc<dyn AgentServer>,
225 cx: &mut Context<Self>,
226 ) -> (
227 Receiver<Option<String>>,
228 Task<Result<AgentConnectedState, LoadError>>,
229 ) {
230 let (new_version_tx, new_version_rx) = watch::channel::<Option<String>>(None);
231
232 let agent_server_store = self.project.read(cx).agent_server_store().clone();
233 let delegate = AgentServerDelegate::new(agent_server_store, Some(new_version_tx));
234
235 let connect_task = server.connect(delegate, self.project.clone(), cx);
236 let connect_task = cx.spawn(async move |_this, _cx| match connect_task.await {
237 Ok(connection) => Ok(AgentConnectedState { connection }),
238 Err(err) => match err.downcast::<LoadError>() {
239 Ok(load_error) => Err(load_error),
240 Err(err) => Err(LoadError::Other(SharedString::from(err.to_string()))),
241 },
242 });
243 (new_version_rx, connect_task)
244 }
245}