agent_connection_store.rs

  1use std::rc::Rc;
  2
  3use acp_thread::{AgentConnection, LoadError};
  4use agent_servers::{AgentServer, AgentServerDelegate};
  5use anyhow::Result;
  6use collections::HashMap;
  7use futures::{FutureExt, future::Shared};
  8use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Subscription, Task};
  9use project::{AgentServerStore, AgentServersUpdated, Project};
 10use watch::Receiver;
 11
 12use crate::{Agent, ThreadHistory};
 13
 14pub enum AgentConnectionEntry {
 15    Connecting {
 16        connect_task: Shared<Task<Result<AgentConnectedState, LoadError>>>,
 17    },
 18    Connected(AgentConnectedState),
 19    Error {
 20        error: LoadError,
 21    },
 22}
 23
 24#[derive(Clone)]
 25pub struct AgentConnectedState {
 26    pub connection: Rc<dyn AgentConnection>,
 27    pub history: Option<Entity<ThreadHistory>>,
 28}
 29
 30impl AgentConnectionEntry {
 31    pub fn wait_for_connection(&self) -> Shared<Task<Result<AgentConnectedState, LoadError>>> {
 32        match self {
 33            AgentConnectionEntry::Connecting { connect_task } => connect_task.clone(),
 34            AgentConnectionEntry::Connected(state) => Task::ready(Ok(state.clone())).shared(),
 35            AgentConnectionEntry::Error { error } => Task::ready(Err(error.clone())).shared(),
 36        }
 37    }
 38
 39    pub fn history(&self) -> Option<&Entity<ThreadHistory>> {
 40        match self {
 41            AgentConnectionEntry::Connected(state) => state.history.as_ref(),
 42            _ => None,
 43        }
 44    }
 45}
 46
 47pub enum AgentConnectionEntryEvent {
 48    NewVersionAvailable(SharedString),
 49}
 50
 51impl EventEmitter<AgentConnectionEntryEvent> for AgentConnectionEntry {}
 52
 53pub struct AgentConnectionStore {
 54    project: Entity<Project>,
 55    entries: HashMap<Agent, Entity<AgentConnectionEntry>>,
 56    _subscriptions: Vec<Subscription>,
 57}
 58
 59impl AgentConnectionStore {
 60    pub fn new(project: Entity<Project>, cx: &mut Context<Self>) -> Self {
 61        let agent_server_store = project.read(cx).agent_server_store().clone();
 62        let subscription = cx.subscribe(&agent_server_store, Self::handle_agent_servers_updated);
 63        Self {
 64            project,
 65            entries: HashMap::default(),
 66            _subscriptions: vec![subscription],
 67        }
 68    }
 69
 70    pub fn entry(&self, key: &Agent) -> Option<&Entity<AgentConnectionEntry>> {
 71        self.entries.get(key)
 72    }
 73
 74    pub fn request_connection(
 75        &mut self,
 76        key: Agent,
 77        server: Rc<dyn AgentServer>,
 78        cx: &mut Context<Self>,
 79    ) -> Entity<AgentConnectionEntry> {
 80        self.entries.get(&key).cloned().unwrap_or_else(|| {
 81            let (mut new_version_rx, connect_task) = self.start_connection(server.clone(), cx);
 82            let connect_task = connect_task.shared();
 83
 84            let entry = cx.new(|_cx| AgentConnectionEntry::Connecting {
 85                connect_task: connect_task.clone(),
 86            });
 87
 88            self.entries.insert(key.clone(), entry.clone());
 89
 90            cx.spawn({
 91                let key = key.clone();
 92                let entry = entry.clone();
 93                async move |this, cx| match connect_task.await {
 94                    Ok(connected_state) => {
 95                        entry.update(cx, |entry, cx| {
 96                            if let AgentConnectionEntry::Connecting { .. } = entry {
 97                                *entry = AgentConnectionEntry::Connected(connected_state);
 98                                cx.notify();
 99                            }
100                        });
101                    }
102                    Err(error) => {
103                        entry.update(cx, |entry, cx| {
104                            if let AgentConnectionEntry::Connecting { .. } = entry {
105                                *entry = AgentConnectionEntry::Error { error };
106                                cx.notify();
107                            }
108                        });
109                        this.update(cx, |this, _cx| this.entries.remove(&key)).ok();
110                    }
111                }
112            })
113            .detach();
114
115            cx.spawn({
116                let entry = entry.clone();
117                async move |this, cx| {
118                    while let Ok(version) = new_version_rx.recv().await {
119                        if let Some(version) = version {
120                            entry.update(cx, |_entry, cx| {
121                                cx.emit(AgentConnectionEntryEvent::NewVersionAvailable(
122                                    version.clone().into(),
123                                ));
124                            });
125                            this.update(cx, |this, _cx| this.entries.remove(&key)).ok();
126                        }
127                    }
128                }
129            })
130            .detach();
131
132            entry
133        })
134    }
135
136    fn handle_agent_servers_updated(
137        &mut self,
138        store: Entity<AgentServerStore>,
139        _: &AgentServersUpdated,
140        cx: &mut Context<Self>,
141    ) {
142        let store = store.read(cx);
143        self.entries.retain(|key, _| match key {
144            Agent::NativeAgent => true,
145            Agent::Custom { id } => store.external_agents.contains_key(id),
146        });
147        cx.notify();
148    }
149
150    fn start_connection(
151        &self,
152        server: Rc<dyn AgentServer>,
153        cx: &mut Context<Self>,
154    ) -> (
155        Receiver<Option<String>>,
156        Task<Result<AgentConnectedState, LoadError>>,
157    ) {
158        let (new_version_tx, new_version_rx) = watch::channel::<Option<String>>(None);
159
160        let agent_server_store = self.project.read(cx).agent_server_store().clone();
161        let delegate = AgentServerDelegate::new(agent_server_store, Some(new_version_tx));
162
163        let connect_task = server.connect(delegate, self.project.clone(), cx);
164        let connect_task = cx.spawn(async move |_this, cx| match connect_task.await {
165            Ok(connection) => cx.update(|cx| {
166                let history = connection
167                    .session_list(cx)
168                    .map(|session_list| cx.new(|cx| ThreadHistory::new(session_list, cx)));
169                Ok(AgentConnectedState {
170                    connection,
171                    history,
172                })
173            }),
174            Err(err) => match err.downcast::<LoadError>() {
175                Ok(load_error) => Err(load_error),
176                Err(err) => Err(LoadError::Other(SharedString::from(err.to_string()))),
177            },
178        });
179        (new_version_rx, connect_task)
180    }
181}