agent_connection_store.rs

  1use std::rc::Rc;
  2
  3use acp_thread::{AgentConnection, LoadError};
  4use agent::ThreadStore;
  5use agent_servers::{AgentServer, AgentServerDelegate};
  6use anyhow::Result;
  7use collections::HashMap;
  8use futures::{FutureExt, future::Shared};
  9use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Subscription, Task};
 10use project::{AgentServerStore, AgentServersUpdated, Project};
 11use watch::Receiver;
 12
 13use crate::{Agent, ThreadHistory};
 14use project::ExternalAgentServerName;
 15
 16pub enum AgentConnectionEntry {
 17    Connecting {
 18        connect_task: Shared<Task<Result<AgentConnectedState, LoadError>>>,
 19    },
 20    Connected(AgentConnectedState),
 21    Error {
 22        error: LoadError,
 23    },
 24}
 25
 26#[derive(Clone)]
 27pub struct AgentConnectedState {
 28    pub connection: Rc<dyn AgentConnection>,
 29    pub history: Entity<ThreadHistory>,
 30}
 31
 32impl AgentConnectionEntry {
 33    pub fn wait_for_connection(&self) -> Shared<Task<Result<AgentConnectedState, LoadError>>> {
 34        match self {
 35            AgentConnectionEntry::Connecting { connect_task } => connect_task.clone(),
 36            AgentConnectionEntry::Connected(state) => Task::ready(Ok(state.clone())).shared(),
 37            AgentConnectionEntry::Error { error } => Task::ready(Err(error.clone())).shared(),
 38        }
 39    }
 40
 41    pub fn history(&self) -> Option<&Entity<ThreadHistory>> {
 42        match self {
 43            AgentConnectionEntry::Connected(state) => Some(&state.history),
 44            _ => None,
 45        }
 46    }
 47}
 48
 49pub enum AgentConnectionEntryEvent {
 50    NewVersionAvailable(SharedString),
 51}
 52
 53impl EventEmitter<AgentConnectionEntryEvent> for AgentConnectionEntry {}
 54
 55pub struct AgentConnectionStore {
 56    project: Entity<Project>,
 57    thread_store: Entity<ThreadStore>,
 58    entries: HashMap<Agent, Entity<AgentConnectionEntry>>,
 59    _subscriptions: Vec<Subscription>,
 60}
 61
 62impl AgentConnectionStore {
 63    pub fn new(
 64        project: Entity<Project>,
 65        thread_store: Entity<ThreadStore>,
 66        cx: &mut Context<Self>,
 67    ) -> Self {
 68        let agent_server_store = project.read(cx).agent_server_store().clone();
 69        let subscription = cx.subscribe(&agent_server_store, Self::handle_agent_servers_updated);
 70        Self {
 71            project,
 72            thread_store,
 73            entries: HashMap::default(),
 74            _subscriptions: vec![subscription],
 75        }
 76    }
 77
 78    pub fn entry(&self, key: &Agent) -> Option<&Entity<AgentConnectionEntry>> {
 79        self.entries.get(key)
 80    }
 81
 82    pub fn request_connection(
 83        &mut self,
 84        agent: Agent,
 85        cx: &mut Context<Self>,
 86    ) -> Entity<AgentConnectionEntry> {
 87        self.entries.get(&agent).cloned().unwrap_or_else(|| {
 88            let server = agent.server(
 89                self.project.read(cx).fs().clone(),
 90                self.thread_store.clone(),
 91            );
 92
 93            let (mut new_version_rx, connect_task) = self.start_connection(server.clone(), cx);
 94            let connect_task = connect_task.shared();
 95
 96            let entry = cx.new(|_cx| AgentConnectionEntry::Connecting {
 97                connect_task: connect_task.clone(),
 98            });
 99
100            self.entries.insert(agent.clone(), entry.clone());
101
102            cx.spawn({
103                let key = agent.clone();
104                let entry = entry.clone();
105                async move |this, cx| match connect_task.await {
106                    Ok(connected_state) => {
107                        entry.update(cx, |entry, cx| {
108                            if let AgentConnectionEntry::Connecting { .. } = entry {
109                                *entry = AgentConnectionEntry::Connected(connected_state);
110                                cx.notify();
111                            }
112                        });
113                    }
114                    Err(error) => {
115                        entry.update(cx, |entry, cx| {
116                            if let AgentConnectionEntry::Connecting { .. } = entry {
117                                *entry = AgentConnectionEntry::Error { error };
118                                cx.notify();
119                            }
120                        });
121                        this.update(cx, |this, _cx| this.entries.remove(&key)).ok();
122                    }
123                }
124            })
125            .detach();
126
127            cx.spawn({
128                let entry = entry.clone();
129                async move |this, cx| {
130                    while let Ok(version) = new_version_rx.recv().await {
131                        if let Some(version) = version {
132                            entry.update(cx, |_entry, cx| {
133                                cx.emit(AgentConnectionEntryEvent::NewVersionAvailable(
134                                    version.clone().into(),
135                                ));
136                            });
137                            this.update(cx, |this, _cx| this.entries.remove(&agent))
138                                .ok();
139                        }
140                    }
141                }
142            })
143            .detach();
144
145            entry
146        })
147    }
148
149    fn handle_agent_servers_updated(
150        &mut self,
151        store: Entity<AgentServerStore>,
152        _: &AgentServersUpdated,
153        cx: &mut Context<Self>,
154    ) {
155        let store = store.read(cx);
156        self.entries.retain(|key, _| match key {
157            Agent::NativeAgent => true,
158            Agent::Custom { name } => store
159                .external_agents
160                .contains_key(&ExternalAgentServerName(name.clone())),
161        });
162        cx.notify();
163    }
164
165    fn start_connection(
166        &self,
167        server: Rc<dyn AgentServer>,
168        cx: &mut Context<Self>,
169    ) -> (
170        Receiver<Option<String>>,
171        Task<Result<AgentConnectedState, LoadError>>,
172    ) {
173        let (new_version_tx, new_version_rx) = watch::channel::<Option<String>>(None);
174
175        let agent_server_store = self.project.read(cx).agent_server_store().clone();
176        let delegate = AgentServerDelegate::new(agent_server_store, Some(new_version_tx));
177
178        let connect_task = server.connect(delegate, cx);
179        let connect_task = cx.spawn(async move |_this, cx| match connect_task.await {
180            Ok(connection) => cx.update(|cx| {
181                let history = cx.new(|cx| ThreadHistory::new(connection.session_list(cx), cx));
182                Ok(AgentConnectedState {
183                    connection,
184                    history,
185                })
186            }),
187            Err(err) => match err.downcast::<LoadError>() {
188                Ok(load_error) => Err(load_error),
189                Err(err) => Err(LoadError::Other(SharedString::from(err.to_string()))),
190            },
191        });
192        (new_version_rx, connect_task)
193    }
194}