// agent_connection_store.rs

  1use std::rc::Rc;
  2
  3use acp_thread::{AgentConnection, LoadError};
  4use agent_servers::{AgentServer, AgentServerDelegate};
  5use anyhow::Result;
  6use collections::HashMap;
  7use futures::{FutureExt, future::Shared};
  8use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Subscription, Task};
  9
 10use project::{AgentServerStore, AgentServersUpdated, Project};
 11use watch::Receiver;
 12
 13use crate::{Agent, ThreadHistory};
 14
 15pub enum AgentConnectionEntry {
 16    Connecting {
 17        connect_task: Shared<Task<Result<AgentConnectedState, LoadError>>>,
 18    },
 19    Connected(AgentConnectedState),
 20    Error {
 21        error: LoadError,
 22    },
 23}
 24
 25#[derive(Clone)]
 26pub struct AgentConnectedState {
 27    pub connection: Rc<dyn AgentConnection>,
 28    pub history: Option<Entity<ThreadHistory>>,
 29}
 30
 31#[derive(Clone, Copy, Debug, PartialEq, Eq)]
 32pub enum AgentConnectionStatus {
 33    Disconnected,
 34    Connecting,
 35    Connected,
 36}
 37
 38impl AgentConnectionEntry {
 39    pub fn wait_for_connection(&self) -> Shared<Task<Result<AgentConnectedState, LoadError>>> {
 40        match self {
 41            AgentConnectionEntry::Connecting { connect_task } => connect_task.clone(),
 42            AgentConnectionEntry::Connected(state) => Task::ready(Ok(state.clone())).shared(),
 43            AgentConnectionEntry::Error { error } => Task::ready(Err(error.clone())).shared(),
 44        }
 45    }
 46
 47    pub fn history(&self) -> Option<&Entity<ThreadHistory>> {
 48        match self {
 49            AgentConnectionEntry::Connected(state) => state.history.as_ref(),
 50            _ => None,
 51        }
 52    }
 53
 54    pub fn status(&self) -> AgentConnectionStatus {
 55        match self {
 56            AgentConnectionEntry::Connecting { .. } => AgentConnectionStatus::Connecting,
 57            AgentConnectionEntry::Connected(_) => AgentConnectionStatus::Connected,
 58            AgentConnectionEntry::Error { .. } => AgentConnectionStatus::Disconnected,
 59        }
 60    }
 61}
 62
 63pub enum AgentConnectionEntryEvent {
 64    NewVersionAvailable(SharedString),
 65}
 66
 67impl EventEmitter<AgentConnectionEntryEvent> for AgentConnectionEntry {}
 68
 69pub struct AgentConnectionStore {
 70    project: Entity<Project>,
 71    entries: HashMap<Agent, Entity<AgentConnectionEntry>>,
 72    _subscriptions: Vec<Subscription>,
 73}
 74
 75impl AgentConnectionStore {
 76    pub fn new(project: Entity<Project>, cx: &mut Context<Self>) -> Self {
 77        let agent_server_store = project.read(cx).agent_server_store().clone();
 78        let subscription = cx.subscribe(&agent_server_store, Self::handle_agent_servers_updated);
 79        Self {
 80            project,
 81            entries: HashMap::default(),
 82            _subscriptions: vec![subscription],
 83        }
 84    }
 85
 86    pub fn project(&self) -> &Entity<Project> {
 87        &self.project
 88    }
 89
 90    pub fn entry(&self, key: &Agent) -> Option<&Entity<AgentConnectionEntry>> {
 91        self.entries.get(key)
 92    }
 93
 94    pub fn connection_status(&self, key: &Agent, cx: &App) -> AgentConnectionStatus {
 95        self.entries
 96            .get(key)
 97            .map(|entry| entry.read(cx).status())
 98            .unwrap_or(AgentConnectionStatus::Disconnected)
 99    }
100
101    pub fn restart_connection(
102        &mut self,
103        key: Agent,
104        server: Rc<dyn AgentServer>,
105        cx: &mut Context<Self>,
106    ) -> Entity<AgentConnectionEntry> {
107        if let Some(entry) = self.entries.get(&key) {
108            if matches!(entry.read(cx), AgentConnectionEntry::Connecting { .. }) {
109                return entry.clone();
110            }
111        }
112
113        self.entries.remove(&key);
114        self.request_connection(key, server, cx)
115    }
116
117    pub fn request_connection(
118        &mut self,
119        key: Agent,
120        server: Rc<dyn AgentServer>,
121        cx: &mut Context<Self>,
122    ) -> Entity<AgentConnectionEntry> {
123        if let Some(entry) = self.entries.get(&key) {
124            return entry.clone();
125        }
126
127        let (mut new_version_rx, connect_task) = self.start_connection(server, cx);
128        let connect_task = connect_task.shared();
129
130        let entry = cx.new(|_cx| AgentConnectionEntry::Connecting {
131            connect_task: connect_task.clone(),
132        });
133
134        self.entries.insert(key.clone(), entry.clone());
135        cx.notify();
136
137        cx.spawn({
138            let key = key.clone();
139            let entry = entry.downgrade();
140            async move |this, cx| match connect_task.await {
141                Ok(connected_state) => {
142                    this.update(cx, move |this, cx| {
143                        if this.entries.get(&key) != entry.upgrade().as_ref() {
144                            return;
145                        }
146
147                        entry
148                            .update(cx, move |entry, cx| {
149                                if let AgentConnectionEntry::Connecting { .. } = entry {
150                                    *entry = AgentConnectionEntry::Connected(connected_state);
151                                    cx.notify();
152                                }
153                            })
154                            .ok();
155                    })
156                    .ok();
157                }
158                Err(error) => {
159                    this.update(cx, move |this, cx| {
160                        if this.entries.get(&key) != entry.upgrade().as_ref() {
161                            return;
162                        }
163
164                        entry
165                            .update(cx, move |entry, cx| {
166                                if let AgentConnectionEntry::Connecting { .. } = entry {
167                                    *entry = AgentConnectionEntry::Error { error };
168                                    cx.notify();
169                                }
170                            })
171                            .ok();
172                        this.entries.remove(&key);
173                        cx.notify();
174                    })
175                    .ok();
176                }
177            }
178        })
179        .detach();
180
181        cx.spawn({
182            let entry = entry.downgrade();
183            async move |this, cx| {
184                while let Ok(version) = new_version_rx.recv().await {
185                    let Some(version) = version else {
186                        continue;
187                    };
188
189                    this.update(cx, move |this, cx| {
190                        if this.entries.get(&key) != entry.upgrade().as_ref() {
191                            return;
192                        }
193
194                        entry
195                            .update(cx, move |_entry, cx| {
196                                cx.emit(AgentConnectionEntryEvent::NewVersionAvailable(
197                                    version.into(),
198                                ));
199                            })
200                            .ok();
201                        this.entries.remove(&key);
202                        cx.notify();
203                    })
204                    .ok();
205                    break;
206                }
207            }
208        })
209        .detach();
210
211        entry
212    }
213
214    fn handle_agent_servers_updated(
215        &mut self,
216        store: Entity<AgentServerStore>,
217        _: &AgentServersUpdated,
218        cx: &mut Context<Self>,
219    ) {
220        let store = store.read(cx);
221        self.entries.retain(|key, _| match key {
222            Agent::NativeAgent => true,
223            Agent::Custom { id } => store.external_agents.contains_key(id),
224        });
225        cx.notify();
226    }
227
228    fn start_connection(
229        &self,
230        server: Rc<dyn AgentServer>,
231        cx: &mut Context<Self>,
232    ) -> (
233        Receiver<Option<String>>,
234        Task<Result<AgentConnectedState, LoadError>>,
235    ) {
236        let (new_version_tx, new_version_rx) = watch::channel::<Option<String>>(None);
237
238        let agent_server_store = self.project.read(cx).agent_server_store().clone();
239        let delegate = AgentServerDelegate::new(agent_server_store, Some(new_version_tx));
240
241        let connect_task = server.connect(delegate, self.project.clone(), cx);
242        let connect_task = cx.spawn(async move |_this, cx| match connect_task.await {
243            Ok(connection) => cx.update(|cx| {
244                let history = connection
245                    .session_list(cx)
246                    .map(|session_list| cx.new(|cx| ThreadHistory::new(session_list, cx)));
247                Ok(AgentConnectedState {
248                    connection,
249                    history,
250                })
251            }),
252            Err(err) => match err.downcast::<LoadError>() {
253                Ok(load_error) => Err(load_error),
254                Err(err) => Err(LoadError::Other(SharedString::from(err.to_string()))),
255            },
256        });
257        (new_version_rx, connect_task)
258    }
259}