agent_connection_store.rs

  1use std::rc::Rc;
  2
  3use acp_thread::{AgentConnection, LoadError};
  4use agent_servers::{AgentServer, AgentServerDelegate};
  5use anyhow::Result;
  6use collections::HashMap;
  7use futures::{FutureExt, future::Shared};
  8use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Subscription, Task};
  9use project::{AgentServerStore, AgentServersUpdated, Project};
 10use watch::Receiver;
 11
 12use crate::{Agent, ThreadHistory};
 13use project::ExternalAgentServerName;
 14
 15pub enum AgentConnectionEntry {
 16    Connecting {
 17        connect_task: Shared<Task<Result<AgentConnectedState, LoadError>>>,
 18    },
 19    Connected(AgentConnectedState),
 20    Error {
 21        error: LoadError,
 22    },
 23}
 24
 25#[derive(Clone)]
 26pub struct AgentConnectedState {
 27    pub connection: Rc<dyn AgentConnection>,
 28    pub history: Entity<ThreadHistory>,
 29}
 30
 31impl AgentConnectionEntry {
 32    pub fn wait_for_connection(&self) -> Shared<Task<Result<AgentConnectedState, LoadError>>> {
 33        match self {
 34            AgentConnectionEntry::Connecting { connect_task } => connect_task.clone(),
 35            AgentConnectionEntry::Connected(state) => Task::ready(Ok(state.clone())).shared(),
 36            AgentConnectionEntry::Error { error } => Task::ready(Err(error.clone())).shared(),
 37        }
 38    }
 39
 40    pub fn history(&self) -> Option<&Entity<ThreadHistory>> {
 41        match self {
 42            AgentConnectionEntry::Connected(state) => Some(&state.history),
 43            _ => None,
 44        }
 45    }
 46}
 47
 48pub enum AgentConnectionEntryEvent {
 49    NewVersionAvailable(SharedString),
 50}
 51
 52impl EventEmitter<AgentConnectionEntryEvent> for AgentConnectionEntry {}
 53
 54pub struct AgentConnectionStore {
 55    project: Entity<Project>,
 56    entries: HashMap<Agent, Entity<AgentConnectionEntry>>,
 57    _subscriptions: Vec<Subscription>,
 58}
 59
 60impl AgentConnectionStore {
 61    pub fn new(project: Entity<Project>, cx: &mut Context<Self>) -> Self {
 62        let agent_server_store = project.read(cx).agent_server_store().clone();
 63        let subscription = cx.subscribe(&agent_server_store, Self::handle_agent_servers_updated);
 64        Self {
 65            project,
 66            entries: HashMap::default(),
 67            _subscriptions: vec![subscription],
 68        }
 69    }
 70
 71    pub fn entry(&self, key: &Agent) -> Option<&Entity<AgentConnectionEntry>> {
 72        self.entries.get(key)
 73    }
 74
 75    pub fn request_connection(
 76        &mut self,
 77        key: Agent,
 78        server: Rc<dyn AgentServer>,
 79        cx: &mut Context<Self>,
 80    ) -> Entity<AgentConnectionEntry> {
 81        self.entries.get(&key).cloned().unwrap_or_else(|| {
 82            let (mut new_version_rx, connect_task) = self.start_connection(server.clone(), cx);
 83            let connect_task = connect_task.shared();
 84
 85            let entry = cx.new(|_cx| AgentConnectionEntry::Connecting {
 86                connect_task: connect_task.clone(),
 87            });
 88
 89            self.entries.insert(key.clone(), entry.clone());
 90
 91            cx.spawn({
 92                let key = key.clone();
 93                let entry = entry.clone();
 94                async move |this, cx| match connect_task.await {
 95                    Ok(connected_state) => {
 96                        entry.update(cx, |entry, cx| {
 97                            if let AgentConnectionEntry::Connecting { .. } = entry {
 98                                *entry = AgentConnectionEntry::Connected(connected_state);
 99                                cx.notify();
100                            }
101                        });
102                    }
103                    Err(error) => {
104                        entry.update(cx, |entry, cx| {
105                            if let AgentConnectionEntry::Connecting { .. } = entry {
106                                *entry = AgentConnectionEntry::Error { error };
107                                cx.notify();
108                            }
109                        });
110                        this.update(cx, |this, _cx| this.entries.remove(&key)).ok();
111                    }
112                }
113            })
114            .detach();
115
116            cx.spawn({
117                let entry = entry.clone();
118                async move |this, cx| {
119                    while let Ok(version) = new_version_rx.recv().await {
120                        if let Some(version) = version {
121                            entry.update(cx, |_entry, cx| {
122                                cx.emit(AgentConnectionEntryEvent::NewVersionAvailable(
123                                    version.clone().into(),
124                                ));
125                            });
126                            this.update(cx, |this, _cx| this.entries.remove(&key)).ok();
127                        }
128                    }
129                }
130            })
131            .detach();
132
133            entry
134        })
135    }
136
137    fn handle_agent_servers_updated(
138        &mut self,
139        store: Entity<AgentServerStore>,
140        _: &AgentServersUpdated,
141        cx: &mut Context<Self>,
142    ) {
143        let store = store.read(cx);
144        self.entries.retain(|key, _| match key {
145            Agent::NativeAgent => true,
146            Agent::Custom { name } => store
147                .external_agents
148                .contains_key(&ExternalAgentServerName(name.clone())),
149        });
150        cx.notify();
151    }
152
153    fn start_connection(
154        &self,
155        server: Rc<dyn AgentServer>,
156        cx: &mut Context<Self>,
157    ) -> (
158        Receiver<Option<String>>,
159        Task<Result<AgentConnectedState, LoadError>>,
160    ) {
161        let (new_version_tx, new_version_rx) = watch::channel::<Option<String>>(None);
162
163        let agent_server_store = self.project.read(cx).agent_server_store().clone();
164        let delegate = AgentServerDelegate::new(agent_server_store, Some(new_version_tx));
165
166        let connect_task = server.connect(delegate, cx);
167        let connect_task = cx.spawn(async move |_this, cx| match connect_task.await {
168            Ok(connection) => cx.update(|cx| {
169                let history = cx.new(|cx| ThreadHistory::new(connection.session_list(cx), cx));
170                Ok(AgentConnectedState {
171                    connection,
172                    history,
173                })
174            }),
175            Err(err) => match err.downcast::<LoadError>() {
176                Ok(load_error) => Err(load_error),
177                Err(err) => Err(LoadError::Other(SharedString::from(err.to_string()))),
178            },
179        });
180        (new_version_rx, connect_task)
181    }
182}