agent_connection_store.rs

  1use std::rc::Rc;
  2
  3use acp_thread::{AgentConnection, LoadError};
  4use agent_servers::{AgentServer, AgentServerDelegate};
  5use anyhow::Result;
  6use collections::HashMap;
  7use futures::{FutureExt, future::Shared};
  8use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Subscription, Task};
  9use project::{AgentServerStore, AgentServersUpdated, Project};
 10use watch::Receiver;
 11
 12use crate::ExternalAgent;
 13use project::ExternalAgentServerName;
 14
 15pub enum ConnectionEntry {
 16    Connecting {
 17        connect_task: Shared<Task<Result<Rc<dyn AgentConnection>, LoadError>>>,
 18    },
 19    Connected {
 20        connection: Rc<dyn AgentConnection>,
 21    },
 22    Error {
 23        error: LoadError,
 24    },
 25}
 26
 27impl ConnectionEntry {
 28    pub fn wait_for_connection(&self) -> Shared<Task<Result<Rc<dyn AgentConnection>, LoadError>>> {
 29        match self {
 30            ConnectionEntry::Connecting { connect_task } => connect_task.clone(),
 31            ConnectionEntry::Connected { connection } => {
 32                Task::ready(Ok(connection.clone())).shared()
 33            }
 34            ConnectionEntry::Error { error } => Task::ready(Err(error.clone())).shared(),
 35        }
 36    }
 37}
 38
 39pub enum ConnectionEntryEvent {
 40    NewVersionAvailable(SharedString),
 41}
 42
 43impl EventEmitter<ConnectionEntryEvent> for ConnectionEntry {}
 44
 45pub struct AgentConnectionStore {
 46    project: Entity<Project>,
 47    entries: HashMap<ExternalAgent, Entity<ConnectionEntry>>,
 48    _subscriptions: Vec<Subscription>,
 49}
 50
 51impl AgentConnectionStore {
 52    pub fn new(project: Entity<Project>, cx: &mut Context<Self>) -> Self {
 53        let agent_server_store = project.read(cx).agent_server_store().clone();
 54        let subscription = cx.subscribe(&agent_server_store, Self::handle_agent_servers_updated);
 55        Self {
 56            project,
 57            entries: HashMap::default(),
 58            _subscriptions: vec![subscription],
 59        }
 60    }
 61
 62    pub fn request_connection(
 63        &mut self,
 64        key: ExternalAgent,
 65        server: Rc<dyn AgentServer>,
 66        cx: &mut Context<Self>,
 67    ) -> Entity<ConnectionEntry> {
 68        self.entries.get(&key).cloned().unwrap_or_else(|| {
 69            let (mut new_version_rx, connect_task) = self.start_connection(server.clone(), cx);
 70            let connect_task = connect_task.shared();
 71
 72            let entry = cx.new(|_cx| ConnectionEntry::Connecting {
 73                connect_task: connect_task.clone(),
 74            });
 75
 76            self.entries.insert(key.clone(), entry.clone());
 77
 78            cx.spawn({
 79                let key = key.clone();
 80                let entry = entry.clone();
 81                async move |this, cx| match connect_task.await {
 82                    Ok(connection) => {
 83                        entry.update(cx, |entry, cx| {
 84                            if let ConnectionEntry::Connecting { .. } = entry {
 85                                *entry = ConnectionEntry::Connected { connection };
 86                                cx.notify();
 87                            }
 88                        });
 89                    }
 90                    Err(error) => {
 91                        entry.update(cx, |entry, cx| {
 92                            if let ConnectionEntry::Connecting { .. } = entry {
 93                                *entry = ConnectionEntry::Error { error };
 94                                cx.notify();
 95                            }
 96                        });
 97                        this.update(cx, |this, _cx| this.entries.remove(&key)).ok();
 98                    }
 99                }
100            })
101            .detach();
102
103            cx.spawn({
104                let entry = entry.clone();
105                async move |this, cx| {
106                    while let Ok(version) = new_version_rx.recv().await {
107                        if let Some(version) = version {
108                            entry.update(cx, |_entry, cx| {
109                                cx.emit(ConnectionEntryEvent::NewVersionAvailable(
110                                    version.clone().into(),
111                                ));
112                            });
113                            this.update(cx, |this, _cx| this.entries.remove(&key)).ok();
114                        }
115                    }
116                }
117            })
118            .detach();
119
120            entry
121        })
122    }
123
124    fn handle_agent_servers_updated(
125        &mut self,
126        store: Entity<AgentServerStore>,
127        _: &AgentServersUpdated,
128        cx: &mut Context<Self>,
129    ) {
130        let store = store.read(cx);
131        self.entries.retain(|key, _| match key {
132            ExternalAgent::NativeAgent => true,
133            ExternalAgent::Custom { name } => store
134                .external_agents
135                .contains_key(&ExternalAgentServerName(name.clone())),
136        });
137        cx.notify();
138    }
139
140    fn start_connection(
141        &self,
142        server: Rc<dyn AgentServer>,
143        cx: &mut Context<Self>,
144    ) -> (
145        Receiver<Option<String>>,
146        Task<Result<Rc<dyn AgentConnection>, LoadError>>,
147    ) {
148        let (new_version_tx, new_version_rx) = watch::channel::<Option<String>>(None);
149
150        let agent_server_store = self.project.read(cx).agent_server_store().clone();
151        let delegate = AgentServerDelegate::new(agent_server_store, Some(new_version_tx));
152
153        let connect_task = server.connect(delegate, cx);
154        let connect_task = cx.spawn(async move |_this, _cx| match connect_task.await {
155            Ok(connection) => Ok(connection),
156            Err(err) => match err.downcast::<LoadError>() {
157                Ok(load_error) => Err(load_error),
158                Err(err) => Err(LoadError::Other(SharedString::from(err.to_string()))),
159            },
160        });
161        (new_version_rx, connect_task)
162    }
163}