1use std::rc::Rc;
2
3use acp_thread::{AgentConnection, LoadError};
4use agent_servers::{AgentServer, AgentServerDelegate};
5use anyhow::Result;
6use collections::HashMap;
7use futures::{FutureExt, future::Shared};
8use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Subscription, Task};
9
10use project::{AgentServerStore, AgentServersUpdated, Project};
11use watch::Receiver;
12
13use crate::{Agent, ThreadHistory};
14
15pub enum AgentConnectionEntry {
16 Connecting {
17 connect_task: Shared<Task<Result<AgentConnectedState, LoadError>>>,
18 },
19 Connected(AgentConnectedState),
20 Error {
21 error: LoadError,
22 },
23}
24
25#[derive(Clone)]
26pub struct AgentConnectedState {
27 pub connection: Rc<dyn AgentConnection>,
28 pub history: Option<Entity<ThreadHistory>>,
29}
30
/// Coarse, copyable summary of an agent's connection state.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum AgentConnectionStatus {
    /// No usable connection: either nothing is tracked or the last attempt failed.
    Disconnected,
    /// A connection attempt is in progress.
    Connecting,
    /// The agent is connected.
    Connected,
}
37
38impl AgentConnectionEntry {
39 pub fn wait_for_connection(&self) -> Shared<Task<Result<AgentConnectedState, LoadError>>> {
40 match self {
41 AgentConnectionEntry::Connecting { connect_task } => connect_task.clone(),
42 AgentConnectionEntry::Connected(state) => Task::ready(Ok(state.clone())).shared(),
43 AgentConnectionEntry::Error { error } => Task::ready(Err(error.clone())).shared(),
44 }
45 }
46
47 pub fn history(&self) -> Option<&Entity<ThreadHistory>> {
48 match self {
49 AgentConnectionEntry::Connected(state) => state.history.as_ref(),
50 _ => None,
51 }
52 }
53
54 pub fn status(&self) -> AgentConnectionStatus {
55 match self {
56 AgentConnectionEntry::Connecting { .. } => AgentConnectionStatus::Connecting,
57 AgentConnectionEntry::Connected(_) => AgentConnectionStatus::Connected,
58 AgentConnectionEntry::Error { .. } => AgentConnectionStatus::Disconnected,
59 }
60 }
61}
62
63pub enum AgentConnectionEntryEvent {
64 NewVersionAvailable(SharedString),
65}
66
67impl EventEmitter<AgentConnectionEntryEvent> for AgentConnectionEntry {}
68
69pub struct AgentConnectionStore {
70 project: Entity<Project>,
71 entries: HashMap<Agent, Entity<AgentConnectionEntry>>,
72 _subscriptions: Vec<Subscription>,
73}
74
75impl AgentConnectionStore {
76 pub fn new(project: Entity<Project>, cx: &mut Context<Self>) -> Self {
77 let agent_server_store = project.read(cx).agent_server_store().clone();
78 let subscription = cx.subscribe(&agent_server_store, Self::handle_agent_servers_updated);
79 Self {
80 project,
81 entries: HashMap::default(),
82 _subscriptions: vec![subscription],
83 }
84 }
85
86 pub fn entry(&self, key: &Agent) -> Option<&Entity<AgentConnectionEntry>> {
87 self.entries.get(key)
88 }
89
90 pub fn connection_status(&self, key: &Agent, cx: &App) -> AgentConnectionStatus {
91 self.entries
92 .get(key)
93 .map(|entry| entry.read(cx).status())
94 .unwrap_or(AgentConnectionStatus::Disconnected)
95 }
96
97 pub fn restart_connection(
98 &mut self,
99 key: Agent,
100 server: Rc<dyn AgentServer>,
101 cx: &mut Context<Self>,
102 ) -> Entity<AgentConnectionEntry> {
103 if let Some(entry) = self.entries.get(&key) {
104 if matches!(entry.read(cx), AgentConnectionEntry::Connecting { .. }) {
105 return entry.clone();
106 }
107 }
108
109 self.entries.remove(&key);
110 self.request_connection(key, server, cx)
111 }
112
113 pub fn request_connection(
114 &mut self,
115 key: Agent,
116 server: Rc<dyn AgentServer>,
117 cx: &mut Context<Self>,
118 ) -> Entity<AgentConnectionEntry> {
119 if let Some(entry) = self.entries.get(&key) {
120 return entry.clone();
121 }
122
123 let (mut new_version_rx, connect_task) = self.start_connection(server, cx);
124 let connect_task = connect_task.shared();
125
126 let entry = cx.new(|_cx| AgentConnectionEntry::Connecting {
127 connect_task: connect_task.clone(),
128 });
129
130 self.entries.insert(key.clone(), entry.clone());
131 cx.notify();
132
133 cx.spawn({
134 let key = key.clone();
135 let entry = entry.downgrade();
136 async move |this, cx| match connect_task.await {
137 Ok(connected_state) => {
138 this.update(cx, move |this, cx| {
139 if this.entries.get(&key) != entry.upgrade().as_ref() {
140 return;
141 }
142
143 entry
144 .update(cx, move |entry, cx| {
145 if let AgentConnectionEntry::Connecting { .. } = entry {
146 *entry = AgentConnectionEntry::Connected(connected_state);
147 cx.notify();
148 }
149 })
150 .ok();
151 })
152 .ok();
153 }
154 Err(error) => {
155 this.update(cx, move |this, cx| {
156 if this.entries.get(&key) != entry.upgrade().as_ref() {
157 return;
158 }
159
160 entry
161 .update(cx, move |entry, cx| {
162 if let AgentConnectionEntry::Connecting { .. } = entry {
163 *entry = AgentConnectionEntry::Error { error };
164 cx.notify();
165 }
166 })
167 .ok();
168 this.entries.remove(&key);
169 cx.notify();
170 })
171 .ok();
172 }
173 }
174 })
175 .detach();
176
177 cx.spawn({
178 let entry = entry.downgrade();
179 async move |this, cx| {
180 while let Ok(version) = new_version_rx.recv().await {
181 let Some(version) = version else {
182 continue;
183 };
184
185 this.update(cx, move |this, cx| {
186 if this.entries.get(&key) != entry.upgrade().as_ref() {
187 return;
188 }
189
190 entry
191 .update(cx, move |_entry, cx| {
192 cx.emit(AgentConnectionEntryEvent::NewVersionAvailable(
193 version.into(),
194 ));
195 })
196 .ok();
197 this.entries.remove(&key);
198 cx.notify();
199 })
200 .ok();
201 break;
202 }
203 }
204 })
205 .detach();
206
207 entry
208 }
209
210 fn handle_agent_servers_updated(
211 &mut self,
212 store: Entity<AgentServerStore>,
213 _: &AgentServersUpdated,
214 cx: &mut Context<Self>,
215 ) {
216 let store = store.read(cx);
217 self.entries.retain(|key, _| match key {
218 Agent::NativeAgent => true,
219 Agent::Custom { id } => store.external_agents.contains_key(id),
220 });
221 cx.notify();
222 }
223
224 fn start_connection(
225 &self,
226 server: Rc<dyn AgentServer>,
227 cx: &mut Context<Self>,
228 ) -> (
229 Receiver<Option<String>>,
230 Task<Result<AgentConnectedState, LoadError>>,
231 ) {
232 let (new_version_tx, new_version_rx) = watch::channel::<Option<String>>(None);
233
234 let agent_server_store = self.project.read(cx).agent_server_store().clone();
235 let delegate = AgentServerDelegate::new(agent_server_store, Some(new_version_tx));
236
237 let connect_task = server.connect(delegate, self.project.clone(), cx);
238 let connect_task = cx.spawn(async move |_this, cx| match connect_task.await {
239 Ok(connection) => cx.update(|cx| {
240 let history = connection
241 .session_list(cx)
242 .map(|session_list| cx.new(|cx| ThreadHistory::new(session_list, cx)));
243 Ok(AgentConnectedState {
244 connection,
245 history,
246 })
247 }),
248 Err(err) => match err.downcast::<LoadError>() {
249 Ok(load_error) => Err(load_error),
250 Err(err) => Err(LoadError::Other(SharedString::from(err.to_string()))),
251 },
252 });
253 (new_version_rx, connect_task)
254 }
255}