// agent_connection_store.rs

  1use std::rc::Rc;
  2
  3use acp_thread::{AgentConnection, LoadError};
  4use agent_client_protocol as acp;
  5use agent_servers::{AgentServer, AgentServerDelegate};
  6use anyhow::Result;
  7use collections::HashMap;
  8use futures::{FutureExt, future::Shared};
  9use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Subscription, Task};
 10
 11use project::{AgentServerStore, AgentServersUpdated, Project};
 12use watch::Receiver;
 13
 14use crate::{Agent, ThreadHistory};
 15
 16pub enum AgentConnectionEntry {
 17    Connecting {
 18        connect_task: Shared<Task<Result<AgentConnectedState, LoadError>>>,
 19    },
 20    Connected(AgentConnectedState),
 21    Error {
 22        error: LoadError,
 23    },
 24}
 25
 26#[derive(Clone)]
 27pub struct AgentConnectedState {
 28    pub connection: Rc<dyn AgentConnection>,
 29    pub history: Option<Entity<ThreadHistory>>,
 30}
 31
 32#[derive(Clone, Copy, Debug, PartialEq, Eq)]
 33pub enum AgentConnectionStatus {
 34    Disconnected,
 35    Connecting,
 36    Connected,
 37}
 38
 39impl AgentConnectionEntry {
 40    pub fn wait_for_connection(&self) -> Shared<Task<Result<AgentConnectedState, LoadError>>> {
 41        match self {
 42            AgentConnectionEntry::Connecting { connect_task } => connect_task.clone(),
 43            AgentConnectionEntry::Connected(state) => Task::ready(Ok(state.clone())).shared(),
 44            AgentConnectionEntry::Error { error } => Task::ready(Err(error.clone())).shared(),
 45        }
 46    }
 47
 48    pub fn history(&self) -> Option<&Entity<ThreadHistory>> {
 49        match self {
 50            AgentConnectionEntry::Connected(state) => state.history.as_ref(),
 51            _ => None,
 52        }
 53    }
 54
 55    pub fn status(&self) -> AgentConnectionStatus {
 56        match self {
 57            AgentConnectionEntry::Connecting { .. } => AgentConnectionStatus::Connecting,
 58            AgentConnectionEntry::Connected(_) => AgentConnectionStatus::Connected,
 59            AgentConnectionEntry::Error { .. } => AgentConnectionStatus::Disconnected,
 60        }
 61    }
 62}
 63
 64pub enum AgentConnectionEntryEvent {
 65    NewVersionAvailable(SharedString),
 66}
 67
 68impl EventEmitter<AgentConnectionEntryEvent> for AgentConnectionEntry {}
 69
 70#[derive(Clone)]
 71pub struct ActiveAcpConnection {
 72    pub agent_id: project::AgentId,
 73    pub connection: Rc<acp::ClientSideConnection>,
 74}
 75
 76pub struct AgentConnectionStore {
 77    project: Entity<Project>,
 78    entries: HashMap<Agent, Entity<AgentConnectionEntry>>,
 79    _subscriptions: Vec<Subscription>,
 80}
 81
 82impl AgentConnectionStore {
 83    pub fn new(project: Entity<Project>, cx: &mut Context<Self>) -> Self {
 84        let agent_server_store = project.read(cx).agent_server_store().clone();
 85        let subscription = cx.subscribe(&agent_server_store, Self::handle_agent_servers_updated);
 86        Self {
 87            project,
 88            entries: HashMap::default(),
 89            _subscriptions: vec![subscription],
 90        }
 91    }
 92
 93    pub fn project(&self) -> &Entity<Project> {
 94        &self.project
 95    }
 96
 97    pub fn entry(&self, key: &Agent) -> Option<&Entity<AgentConnectionEntry>> {
 98        self.entries.get(key)
 99    }
100
101    pub fn connection_status(&self, key: &Agent, cx: &App) -> AgentConnectionStatus {
102        self.entries
103            .get(key)
104            .map(|entry| entry.read(cx).status())
105            .unwrap_or(AgentConnectionStatus::Disconnected)
106    }
107
108    pub fn active_acp_connections(&self, cx: &App) -> Vec<ActiveAcpConnection> {
109        self.entries
110            .values()
111            .filter_map(|entry| match entry.read(cx) {
112                AgentConnectionEntry::Connected(state) => state
113                    .connection
114                    .clone()
115                    .downcast::<agent_servers::AcpConnection>()
116                    .map(|connection| ActiveAcpConnection {
117                        agent_id: state.connection.agent_id(),
118                        connection: connection.client_connection().clone(),
119                    }),
120                AgentConnectionEntry::Connecting { .. } | AgentConnectionEntry::Error { .. } => {
121                    None
122                }
123            })
124            .collect()
125    }
126
127    pub fn restart_connection(
128        &mut self,
129        key: Agent,
130        server: Rc<dyn AgentServer>,
131        cx: &mut Context<Self>,
132    ) -> Entity<AgentConnectionEntry> {
133        if let Some(entry) = self.entries.get(&key) {
134            if matches!(entry.read(cx), AgentConnectionEntry::Connecting { .. }) {
135                return entry.clone();
136            }
137        }
138
139        self.entries.remove(&key);
140        self.request_connection(key, server, cx)
141    }
142
143    pub fn request_connection(
144        &mut self,
145        key: Agent,
146        server: Rc<dyn AgentServer>,
147        cx: &mut Context<Self>,
148    ) -> Entity<AgentConnectionEntry> {
149        if let Some(entry) = self.entries.get(&key) {
150            return entry.clone();
151        }
152
153        let (mut new_version_rx, connect_task) = self.start_connection(server, cx);
154        let connect_task = connect_task.shared();
155
156        let entry = cx.new(|_cx| AgentConnectionEntry::Connecting {
157            connect_task: connect_task.clone(),
158        });
159
160        self.entries.insert(key.clone(), entry.clone());
161        cx.notify();
162
163        cx.spawn({
164            let key = key.clone();
165            let entry = entry.downgrade();
166            async move |this, cx| match connect_task.await {
167                Ok(connected_state) => {
168                    this.update(cx, move |this, cx| {
169                        if this.entries.get(&key) != entry.upgrade().as_ref() {
170                            return;
171                        }
172
173                        entry
174                            .update(cx, move |entry, cx| {
175                                if let AgentConnectionEntry::Connecting { .. } = entry {
176                                    *entry = AgentConnectionEntry::Connected(connected_state);
177                                    cx.notify();
178                                }
179                            })
180                            .ok();
181                        cx.notify();
182                    })
183                    .ok();
184                }
185                Err(error) => {
186                    this.update(cx, move |this, cx| {
187                        if this.entries.get(&key) != entry.upgrade().as_ref() {
188                            return;
189                        }
190
191                        entry
192                            .update(cx, move |entry, cx| {
193                                if let AgentConnectionEntry::Connecting { .. } = entry {
194                                    *entry = AgentConnectionEntry::Error { error };
195                                    cx.notify();
196                                }
197                            })
198                            .ok();
199                        this.entries.remove(&key);
200                        cx.notify();
201                    })
202                    .ok();
203                }
204            }
205        })
206        .detach();
207
208        cx.spawn({
209            let entry = entry.downgrade();
210            async move |this, cx| {
211                while let Ok(version) = new_version_rx.recv().await {
212                    let Some(version) = version else {
213                        continue;
214                    };
215
216                    this.update(cx, move |this, cx| {
217                        if this.entries.get(&key) != entry.upgrade().as_ref() {
218                            return;
219                        }
220
221                        entry
222                            .update(cx, move |_entry, cx| {
223                                cx.emit(AgentConnectionEntryEvent::NewVersionAvailable(
224                                    version.into(),
225                                ));
226                            })
227                            .ok();
228                        this.entries.remove(&key);
229                        cx.notify();
230                    })
231                    .ok();
232                    break;
233                }
234            }
235        })
236        .detach();
237
238        entry
239    }
240
241    fn handle_agent_servers_updated(
242        &mut self,
243        store: Entity<AgentServerStore>,
244        _: &AgentServersUpdated,
245        cx: &mut Context<Self>,
246    ) {
247        let external_agent_ids: Vec<_> = {
248            let store = store.read(cx);
249            store.external_agents.keys().cloned().collect()
250        };
251
252        self.entries.retain(|key, _| match key {
253            Agent::NativeAgent => true,
254            Agent::Custom { id } => external_agent_ids.contains(id),
255        });
256        cx.notify();
257    }
258
259    fn start_connection(
260        &self,
261        server: Rc<dyn AgentServer>,
262        cx: &mut Context<Self>,
263    ) -> (
264        Receiver<Option<String>>,
265        Task<Result<AgentConnectedState, LoadError>>,
266    ) {
267        let (new_version_tx, new_version_rx) = watch::channel::<Option<String>>(None);
268
269        let agent_server_store = self.project.read(cx).agent_server_store().clone();
270        let delegate = AgentServerDelegate::new(agent_server_store, Some(new_version_tx));
271
272        let connect_task = server.connect(delegate, self.project.clone(), cx);
273        let connect_task = cx.spawn(async move |_this, cx| match connect_task.await {
274            Ok(connection) => cx.update(|cx| {
275                let history = connection
276                    .session_list(cx)
277                    .map(|session_list| cx.new(|cx| ThreadHistory::new(session_list, cx)));
278                Ok(AgentConnectedState {
279                    connection,
280                    history,
281                })
282            }),
283            Err(err) => match err.downcast::<LoadError>() {
284                Ok(load_error) => Err(load_error),
285                Err(err) => Err(LoadError::Other(SharedString::from(err.to_string()))),
286            },
287        });
288        (new_version_rx, connect_task)
289    }
290}