// agent_connection_store.rs

use std::rc::Rc;

use acp_thread::{AgentConnection, LoadError};
use agent_servers::AcpConnection;
use agent_servers::{AgentServer, AgentServerDelegate};
use anyhow::Result;
use collections::HashMap;
use futures::{FutureExt, future::Shared};
use gpui::{
    App, AppContext, Context, Entity, EventEmitter, SharedString, Subscription, Task, WeakEntity,
};

use project::{AgentServerStore, AgentServersUpdated, Project};
use watch::Receiver;

use crate::Agent;
 15
 16pub enum AgentConnectionEntry {
 17    Connecting {
 18        connect_task: Shared<Task<Result<AgentConnectedState, LoadError>>>,
 19    },
 20    Connected(AgentConnectedState),
 21    Error {
 22        error: LoadError,
 23    },
 24}
 25
 26#[derive(Clone)]
 27pub struct AgentConnectedState {
 28    pub connection: Rc<dyn AgentConnection>,
 29}
 30
 31#[derive(Clone, Copy, Debug, PartialEq, Eq)]
 32pub enum AgentConnectionStatus {
 33    Disconnected,
 34    Connecting,
 35    Connected,
 36}
 37
 38impl AgentConnectionEntry {
 39    pub fn wait_for_connection(&self) -> Shared<Task<Result<AgentConnectedState, LoadError>>> {
 40        match self {
 41            AgentConnectionEntry::Connecting { connect_task } => connect_task.clone(),
 42            AgentConnectionEntry::Connected(state) => Task::ready(Ok(state.clone())).shared(),
 43            AgentConnectionEntry::Error { error } => Task::ready(Err(error.clone())).shared(),
 44        }
 45    }
 46
 47    pub fn status(&self) -> AgentConnectionStatus {
 48        match self {
 49            AgentConnectionEntry::Connecting { .. } => AgentConnectionStatus::Connecting,
 50            AgentConnectionEntry::Connected(_) => AgentConnectionStatus::Connected,
 51            AgentConnectionEntry::Error { .. } => AgentConnectionStatus::Disconnected,
 52        }
 53    }
 54}
 55
 56pub enum AgentConnectionEntryEvent {
 57    NewVersionAvailable(SharedString),
 58}
 59
 60impl EventEmitter<AgentConnectionEntryEvent> for AgentConnectionEntry {}
 61
 62#[derive(Clone)]
 63pub struct ActiveAcpConnection {
 64    pub agent_id: project::AgentId,
 65    pub connection: Rc<AcpConnection>,
 66}
 67
 68pub struct AgentConnectionStore {
 69    project: Entity<Project>,
 70    entries: HashMap<Agent, Entity<AgentConnectionEntry>>,
 71    _subscriptions: Vec<Subscription>,
 72}
 73
 74impl AgentConnectionStore {
 75    pub fn new(project: Entity<Project>, cx: &mut Context<Self>) -> Self {
 76        let agent_server_store = project.read(cx).agent_server_store().clone();
 77        let subscription = cx.subscribe(&agent_server_store, Self::handle_agent_servers_updated);
 78        Self {
 79            project,
 80            entries: HashMap::default(),
 81            _subscriptions: vec![subscription],
 82        }
 83    }
 84
 85    pub fn project(&self) -> &Entity<Project> {
 86        &self.project
 87    }
 88
 89    pub fn entry(&self, key: &Agent) -> Option<&Entity<AgentConnectionEntry>> {
 90        self.entries.get(key)
 91    }
 92
 93    pub fn connection_status(&self, key: &Agent, cx: &App) -> AgentConnectionStatus {
 94        self.entries
 95            .get(key)
 96            .map(|entry| entry.read(cx).status())
 97            .unwrap_or(AgentConnectionStatus::Disconnected)
 98    }
 99
100    pub fn active_acp_connections(&self, cx: &App) -> Vec<ActiveAcpConnection> {
101        self.entries
102            .values()
103            .filter_map(|entry| match entry.read(cx) {
104                AgentConnectionEntry::Connected(state) => state
105                    .connection
106                    .clone()
107                    .downcast::<AcpConnection>()
108                    .map(|connection| ActiveAcpConnection {
109                        agent_id: state.connection.agent_id(),
110                        connection,
111                    }),
112                AgentConnectionEntry::Connecting { .. } | AgentConnectionEntry::Error { .. } => {
113                    None
114                }
115            })
116            .collect()
117    }
118
119    pub fn restart_connection(
120        &mut self,
121        key: Agent,
122        server: Rc<dyn AgentServer>,
123        cx: &mut Context<Self>,
124    ) -> Entity<AgentConnectionEntry> {
125        if let Some(entry) = self.entries.get(&key) {
126            if matches!(entry.read(cx), AgentConnectionEntry::Connecting { .. }) {
127                return entry.clone();
128            }
129        }
130
131        self.entries.remove(&key);
132        self.request_connection(key, server, cx)
133    }
134
135    pub fn request_connection(
136        &mut self,
137        key: Agent,
138        server: Rc<dyn AgentServer>,
139        cx: &mut Context<Self>,
140    ) -> Entity<AgentConnectionEntry> {
141        if let Some(entry) = self.entries.get(&key) {
142            return entry.clone();
143        }
144
145        let (mut new_version_rx, connect_task) = self.start_connection(server, cx);
146        let connect_task = connect_task.shared();
147
148        let entry = cx.new(|_cx| AgentConnectionEntry::Connecting {
149            connect_task: connect_task.clone(),
150        });
151
152        self.entries.insert(key.clone(), entry.clone());
153        cx.notify();
154
155        cx.spawn({
156            let key = key.clone();
157            let entry = entry.downgrade();
158            async move |this, cx| match connect_task.await {
159                Ok(connected_state) => {
160                    this.update(cx, move |this, cx| {
161                        if this.entries.get(&key) != entry.upgrade().as_ref() {
162                            return;
163                        }
164
165                        entry
166                            .update(cx, move |entry, cx| {
167                                if let AgentConnectionEntry::Connecting { .. } = entry {
168                                    *entry = AgentConnectionEntry::Connected(connected_state);
169                                    cx.notify();
170                                }
171                            })
172                            .ok();
173                        cx.notify();
174                    })
175                    .ok();
176                }
177                Err(error) => {
178                    this.update(cx, move |this, cx| {
179                        if this.entries.get(&key) != entry.upgrade().as_ref() {
180                            return;
181                        }
182
183                        entry
184                            .update(cx, move |entry, cx| {
185                                if let AgentConnectionEntry::Connecting { .. } = entry {
186                                    *entry = AgentConnectionEntry::Error { error };
187                                    cx.notify();
188                                }
189                            })
190                            .ok();
191                        this.entries.remove(&key);
192                        cx.notify();
193                    })
194                    .ok();
195                }
196            }
197        })
198        .detach();
199
200        cx.spawn({
201            let entry = entry.downgrade();
202            async move |this, cx| {
203                while let Ok(version) = new_version_rx.recv().await {
204                    let Some(version) = version else {
205                        continue;
206                    };
207
208                    this.update(cx, move |this, cx| {
209                        if this.entries.get(&key) != entry.upgrade().as_ref() {
210                            return;
211                        }
212
213                        entry
214                            .update(cx, move |_entry, cx| {
215                                cx.emit(AgentConnectionEntryEvent::NewVersionAvailable(
216                                    version.into(),
217                                ));
218                            })
219                            .ok();
220                        this.entries.remove(&key);
221                        cx.notify();
222                    })
223                    .ok();
224                    break;
225                }
226            }
227        })
228        .detach();
229
230        entry
231    }
232
233    fn handle_agent_servers_updated(
234        &mut self,
235        store: Entity<AgentServerStore>,
236        _: &AgentServersUpdated,
237        cx: &mut Context<Self>,
238    ) {
239        let store = store.read(cx);
240        self.entries.retain(|key, _| match key {
241            Agent::NativeAgent => true,
242            Agent::Custom { id } => store.external_agents.contains_key(id),
243            #[cfg(any(test, feature = "test-support"))]
244            Agent::Stub => true,
245        });
246        cx.notify();
247    }
248
249    fn start_connection(
250        &self,
251        server: Rc<dyn AgentServer>,
252        cx: &mut Context<Self>,
253    ) -> (
254        Receiver<Option<String>>,
255        Task<Result<AgentConnectedState, LoadError>>,
256    ) {
257        let (new_version_tx, new_version_rx) = watch::channel::<Option<String>>(None);
258
259        let agent_server_store = self.project.read(cx).agent_server_store().clone();
260        let delegate = AgentServerDelegate::new(agent_server_store, Some(new_version_tx));
261
262        let connect_task = server.connect(delegate, self.project.clone(), cx);
263        let connect_task = cx.spawn(async move |_this, _cx| match connect_task.await {
264            Ok(connection) => Ok(AgentConnectedState { connection }),
265            Err(err) => match err.downcast::<LoadError>() {
266                Ok(load_error) => Err(load_error),
267                Err(err) => Err(LoadError::Other(SharedString::from(err.to_string()))),
268            },
269        });
270        (new_version_rx, connect_task)
271    }
272}