agent_connection_store.rs

  1use std::rc::Rc;
  2
  3use acp_thread::{AgentConnection, LoadError};
  4use agent_servers::{AgentServer, AgentServerDelegate};
  5use anyhow::Result;
  6use collections::HashMap;
  7use futures::{FutureExt, future::Shared};
  8use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Subscription, Task};
  9
 10use project::{AgentServerStore, AgentServersUpdated, Project};
 11use watch::Receiver;
 12
 13use crate::Agent;
 14
 15pub enum AgentConnectionEntry {
 16    Connecting {
 17        connect_task: Shared<Task<Result<AgentConnectedState, LoadError>>>,
 18    },
 19    Connected(AgentConnectedState),
 20    Error {
 21        error: LoadError,
 22    },
 23}
 24
 25#[derive(Clone)]
 26pub struct AgentConnectedState {
 27    pub connection: Rc<dyn AgentConnection>,
 28}
 29
 /// Coarse connection state, as reported by `AgentConnectionEntry::status` and
 /// `AgentConnectionStore::connection_status`.
 #[derive(Clone, Copy, Debug, PartialEq, Eq)]
 pub enum AgentConnectionStatus {
     /// No usable connection: reported for entries in the `Error` state and for
     /// agents that have no entry in the store.
     Disconnected,
     /// A connection attempt is in progress.
     Connecting,
     /// The connection is established.
     Connected,
 }
 36
 37impl AgentConnectionEntry {
 38    pub fn wait_for_connection(&self) -> Shared<Task<Result<AgentConnectedState, LoadError>>> {
 39        match self {
 40            AgentConnectionEntry::Connecting { connect_task } => connect_task.clone(),
 41            AgentConnectionEntry::Connected(state) => Task::ready(Ok(state.clone())).shared(),
 42            AgentConnectionEntry::Error { error } => Task::ready(Err(error.clone())).shared(),
 43        }
 44    }
 45
 46    pub fn status(&self) -> AgentConnectionStatus {
 47        match self {
 48            AgentConnectionEntry::Connecting { .. } => AgentConnectionStatus::Connecting,
 49            AgentConnectionEntry::Connected(_) => AgentConnectionStatus::Connected,
 50            AgentConnectionEntry::Error { .. } => AgentConnectionStatus::Disconnected,
 51        }
 52    }
 53}
 54
 55pub enum AgentConnectionEntryEvent {
 56    NewVersionAvailable(SharedString),
 57}
 58
 59impl EventEmitter<AgentConnectionEntryEvent> for AgentConnectionEntry {}
 60
 61pub struct AgentConnectionStore {
 62    project: Entity<Project>,
 63    entries: HashMap<Agent, Entity<AgentConnectionEntry>>,
 64    _subscriptions: Vec<Subscription>,
 65}
 66
 67impl AgentConnectionStore {
 68    pub fn new(project: Entity<Project>, cx: &mut Context<Self>) -> Self {
 69        let agent_server_store = project.read(cx).agent_server_store().clone();
 70        let subscription = cx.subscribe(&agent_server_store, Self::handle_agent_servers_updated);
 71        Self {
 72            project,
 73            entries: HashMap::default(),
 74            _subscriptions: vec![subscription],
 75        }
 76    }
 77
 78    pub fn project(&self) -> &Entity<Project> {
 79        &self.project
 80    }
 81
 82    pub fn entry(&self, key: &Agent) -> Option<&Entity<AgentConnectionEntry>> {
 83        self.entries.get(key)
 84    }
 85
 86    pub fn connection_status(&self, key: &Agent, cx: &App) -> AgentConnectionStatus {
 87        self.entries
 88            .get(key)
 89            .map(|entry| entry.read(cx).status())
 90            .unwrap_or(AgentConnectionStatus::Disconnected)
 91    }
 92
 93    pub fn restart_connection(
 94        &mut self,
 95        key: Agent,
 96        server: Rc<dyn AgentServer>,
 97        cx: &mut Context<Self>,
 98    ) -> Entity<AgentConnectionEntry> {
 99        if let Some(entry) = self.entries.get(&key) {
100            if matches!(entry.read(cx), AgentConnectionEntry::Connecting { .. }) {
101                return entry.clone();
102            }
103        }
104
105        self.entries.remove(&key);
106        self.request_connection(key, server, cx)
107    }
108
109    pub fn request_connection(
110        &mut self,
111        key: Agent,
112        server: Rc<dyn AgentServer>,
113        cx: &mut Context<Self>,
114    ) -> Entity<AgentConnectionEntry> {
115        if let Some(entry) = self.entries.get(&key) {
116            return entry.clone();
117        }
118
119        let (mut new_version_rx, connect_task) = self.start_connection(server, cx);
120        let connect_task = connect_task.shared();
121
122        let entry = cx.new(|_cx| AgentConnectionEntry::Connecting {
123            connect_task: connect_task.clone(),
124        });
125
126        self.entries.insert(key.clone(), entry.clone());
127        cx.notify();
128
129        cx.spawn({
130            let key = key.clone();
131            let entry = entry.downgrade();
132            async move |this, cx| match connect_task.await {
133                Ok(connected_state) => {
134                    this.update(cx, move |this, cx| {
135                        if this.entries.get(&key) != entry.upgrade().as_ref() {
136                            return;
137                        }
138
139                        entry
140                            .update(cx, move |entry, cx| {
141                                if let AgentConnectionEntry::Connecting { .. } = entry {
142                                    *entry = AgentConnectionEntry::Connected(connected_state);
143                                    cx.notify();
144                                }
145                            })
146                            .ok();
147                    })
148                    .ok();
149                }
150                Err(error) => {
151                    this.update(cx, move |this, cx| {
152                        if this.entries.get(&key) != entry.upgrade().as_ref() {
153                            return;
154                        }
155
156                        entry
157                            .update(cx, move |entry, cx| {
158                                if let AgentConnectionEntry::Connecting { .. } = entry {
159                                    *entry = AgentConnectionEntry::Error { error };
160                                    cx.notify();
161                                }
162                            })
163                            .ok();
164                        this.entries.remove(&key);
165                        cx.notify();
166                    })
167                    .ok();
168                }
169            }
170        })
171        .detach();
172
173        cx.spawn({
174            let entry = entry.downgrade();
175            async move |this, cx| {
176                while let Ok(version) = new_version_rx.recv().await {
177                    let Some(version) = version else {
178                        continue;
179                    };
180
181                    this.update(cx, move |this, cx| {
182                        if this.entries.get(&key) != entry.upgrade().as_ref() {
183                            return;
184                        }
185
186                        entry
187                            .update(cx, move |_entry, cx| {
188                                cx.emit(AgentConnectionEntryEvent::NewVersionAvailable(
189                                    version.into(),
190                                ));
191                            })
192                            .ok();
193                        this.entries.remove(&key);
194                        cx.notify();
195                    })
196                    .ok();
197                    break;
198                }
199            }
200        })
201        .detach();
202
203        entry
204    }
205
206    fn handle_agent_servers_updated(
207        &mut self,
208        store: Entity<AgentServerStore>,
209        _: &AgentServersUpdated,
210        cx: &mut Context<Self>,
211    ) {
212        let store = store.read(cx);
213        self.entries.retain(|key, _| match key {
214            Agent::NativeAgent => true,
215            Agent::Custom { id } => store.external_agents.contains_key(id),
216            #[cfg(any(test, feature = "test-support"))]
217            Agent::Stub => true,
218        });
219        cx.notify();
220    }
221
222    fn start_connection(
223        &self,
224        server: Rc<dyn AgentServer>,
225        cx: &mut Context<Self>,
226    ) -> (
227        Receiver<Option<String>>,
228        Task<Result<AgentConnectedState, LoadError>>,
229    ) {
230        let (new_version_tx, new_version_rx) = watch::channel::<Option<String>>(None);
231
232        let agent_server_store = self.project.read(cx).agent_server_store().clone();
233        let delegate = AgentServerDelegate::new(agent_server_store, Some(new_version_tx));
234
235        let connect_task = server.connect(delegate, self.project.clone(), cx);
236        let connect_task = cx.spawn(async move |_this, _cx| match connect_task.await {
237            Ok(connection) => Ok(AgentConnectedState { connection }),
238            Err(err) => match err.downcast::<LoadError>() {
239                Ok(load_error) => Err(load_error),
240                Err(err) => Err(LoadError::Other(SharedString::from(err.to_string()))),
241            },
242        });
243        (new_version_rx, connect_task)
244    }
245}