1use std::rc::Rc;
2
3use acp_thread::{AgentConnection, LoadError};
4use agent_servers::{AgentServer, AgentServerDelegate};
5use anyhow::Result;
6use collections::HashMap;
7use futures::{FutureExt, future::Shared};
8use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Subscription, Task};
9
10use project::{AgentServerStore, AgentServersUpdated, Project};
11use watch::Receiver;
12
13use crate::{Agent, ThreadHistory};
14
15pub enum AgentConnectionEntry {
16 Connecting {
17 connect_task: Shared<Task<Result<AgentConnectedState, LoadError>>>,
18 },
19 Connected(AgentConnectedState),
20 Error {
21 error: LoadError,
22 },
23}
24
25#[derive(Clone)]
26pub struct AgentConnectedState {
27 pub connection: Rc<dyn AgentConnection>,
28 pub history: Option<Entity<ThreadHistory>>,
29}
30
/// Coarse connection status, as reported by `AgentConnectionEntry::status`
/// and `AgentConnectionStore::connection_status`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum AgentConnectionStatus {
    /// No entry is tracked, or the last attempt failed.
    Disconnected,
    /// A connection attempt is in progress.
    Connecting,
    /// The agent is connected.
    Connected,
}
37
38impl AgentConnectionEntry {
39 pub fn wait_for_connection(&self) -> Shared<Task<Result<AgentConnectedState, LoadError>>> {
40 match self {
41 AgentConnectionEntry::Connecting { connect_task } => connect_task.clone(),
42 AgentConnectionEntry::Connected(state) => Task::ready(Ok(state.clone())).shared(),
43 AgentConnectionEntry::Error { error } => Task::ready(Err(error.clone())).shared(),
44 }
45 }
46
47 pub fn history(&self) -> Option<&Entity<ThreadHistory>> {
48 match self {
49 AgentConnectionEntry::Connected(state) => state.history.as_ref(),
50 _ => None,
51 }
52 }
53
54 pub fn status(&self) -> AgentConnectionStatus {
55 match self {
56 AgentConnectionEntry::Connecting { .. } => AgentConnectionStatus::Connecting,
57 AgentConnectionEntry::Connected(_) => AgentConnectionStatus::Connected,
58 AgentConnectionEntry::Error { .. } => AgentConnectionStatus::Disconnected,
59 }
60 }
61}
62
63pub enum AgentConnectionEntryEvent {
64 NewVersionAvailable(SharedString),
65}
66
67impl EventEmitter<AgentConnectionEntryEvent> for AgentConnectionEntry {}
68
69pub struct AgentConnectionStore {
70 project: Entity<Project>,
71 entries: HashMap<Agent, Entity<AgentConnectionEntry>>,
72 _subscriptions: Vec<Subscription>,
73}
74
75impl AgentConnectionStore {
76 pub fn new(project: Entity<Project>, cx: &mut Context<Self>) -> Self {
77 let agent_server_store = project.read(cx).agent_server_store().clone();
78 let subscription = cx.subscribe(&agent_server_store, Self::handle_agent_servers_updated);
79 Self {
80 project,
81 entries: HashMap::default(),
82 _subscriptions: vec![subscription],
83 }
84 }
85
86 pub fn project(&self) -> &Entity<Project> {
87 &self.project
88 }
89
90 pub fn entry(&self, key: &Agent) -> Option<&Entity<AgentConnectionEntry>> {
91 self.entries.get(key)
92 }
93
94 pub fn connection_status(&self, key: &Agent, cx: &App) -> AgentConnectionStatus {
95 self.entries
96 .get(key)
97 .map(|entry| entry.read(cx).status())
98 .unwrap_or(AgentConnectionStatus::Disconnected)
99 }
100
101 pub fn restart_connection(
102 &mut self,
103 key: Agent,
104 server: Rc<dyn AgentServer>,
105 cx: &mut Context<Self>,
106 ) -> Entity<AgentConnectionEntry> {
107 if let Some(entry) = self.entries.get(&key) {
108 if matches!(entry.read(cx), AgentConnectionEntry::Connecting { .. }) {
109 return entry.clone();
110 }
111 }
112
113 self.entries.remove(&key);
114 self.request_connection(key, server, cx)
115 }
116
117 pub fn request_connection(
118 &mut self,
119 key: Agent,
120 server: Rc<dyn AgentServer>,
121 cx: &mut Context<Self>,
122 ) -> Entity<AgentConnectionEntry> {
123 if let Some(entry) = self.entries.get(&key) {
124 return entry.clone();
125 }
126
127 let (mut new_version_rx, connect_task) = self.start_connection(server, cx);
128 let connect_task = connect_task.shared();
129
130 let entry = cx.new(|_cx| AgentConnectionEntry::Connecting {
131 connect_task: connect_task.clone(),
132 });
133
134 self.entries.insert(key.clone(), entry.clone());
135 cx.notify();
136
137 cx.spawn({
138 let key = key.clone();
139 let entry = entry.downgrade();
140 async move |this, cx| match connect_task.await {
141 Ok(connected_state) => {
142 this.update(cx, move |this, cx| {
143 if this.entries.get(&key) != entry.upgrade().as_ref() {
144 return;
145 }
146
147 entry
148 .update(cx, move |entry, cx| {
149 if let AgentConnectionEntry::Connecting { .. } = entry {
150 *entry = AgentConnectionEntry::Connected(connected_state);
151 cx.notify();
152 }
153 })
154 .ok();
155 })
156 .ok();
157 }
158 Err(error) => {
159 this.update(cx, move |this, cx| {
160 if this.entries.get(&key) != entry.upgrade().as_ref() {
161 return;
162 }
163
164 entry
165 .update(cx, move |entry, cx| {
166 if let AgentConnectionEntry::Connecting { .. } = entry {
167 *entry = AgentConnectionEntry::Error { error };
168 cx.notify();
169 }
170 })
171 .ok();
172 this.entries.remove(&key);
173 cx.notify();
174 })
175 .ok();
176 }
177 }
178 })
179 .detach();
180
181 cx.spawn({
182 let entry = entry.downgrade();
183 async move |this, cx| {
184 while let Ok(version) = new_version_rx.recv().await {
185 let Some(version) = version else {
186 continue;
187 };
188
189 this.update(cx, move |this, cx| {
190 if this.entries.get(&key) != entry.upgrade().as_ref() {
191 return;
192 }
193
194 entry
195 .update(cx, move |_entry, cx| {
196 cx.emit(AgentConnectionEntryEvent::NewVersionAvailable(
197 version.into(),
198 ));
199 })
200 .ok();
201 this.entries.remove(&key);
202 cx.notify();
203 })
204 .ok();
205 break;
206 }
207 }
208 })
209 .detach();
210
211 entry
212 }
213
214 fn handle_agent_servers_updated(
215 &mut self,
216 store: Entity<AgentServerStore>,
217 _: &AgentServersUpdated,
218 cx: &mut Context<Self>,
219 ) {
220 let store = store.read(cx);
221 self.entries.retain(|key, _| match key {
222 Agent::NativeAgent => true,
223 Agent::Custom { id } => store.external_agents.contains_key(id),
224 });
225 cx.notify();
226 }
227
228 fn start_connection(
229 &self,
230 server: Rc<dyn AgentServer>,
231 cx: &mut Context<Self>,
232 ) -> (
233 Receiver<Option<String>>,
234 Task<Result<AgentConnectedState, LoadError>>,
235 ) {
236 let (new_version_tx, new_version_rx) = watch::channel::<Option<String>>(None);
237
238 let agent_server_store = self.project.read(cx).agent_server_store().clone();
239 let delegate = AgentServerDelegate::new(agent_server_store, Some(new_version_tx));
240
241 let connect_task = server.connect(delegate, self.project.clone(), cx);
242 let connect_task = cx.spawn(async move |_this, cx| match connect_task.await {
243 Ok(connection) => cx.update(|cx| {
244 let history = connection
245 .session_list(cx)
246 .map(|session_list| cx.new(|cx| ThreadHistory::new(session_list, cx)));
247 Ok(AgentConnectedState {
248 connection,
249 history,
250 })
251 }),
252 Err(err) => match err.downcast::<LoadError>() {
253 Ok(load_error) => Err(load_error),
254 Err(err) => Err(LoadError::Other(SharedString::from(err.to_string()))),
255 },
256 });
257 (new_version_rx, connect_task)
258 }
259}