1use std::rc::Rc;
2
3use acp_thread::{AgentConnection, LoadError};
4use agent_servers::{AgentServer, AgentServerDelegate};
5use anyhow::Result;
6use collections::HashMap;
7use futures::{FutureExt, future::Shared};
8use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Subscription, Task};
9use project::{AgentServerStore, AgentServersUpdated, Project};
10use watch::Receiver;
11
12use crate::ExternalAgent;
13use project::ExternalAgentServerName;
14
15pub enum ConnectionEntry {
16 Connecting {
17 connect_task: Shared<Task<Result<Rc<dyn AgentConnection>, LoadError>>>,
18 },
19 Connected {
20 connection: Rc<dyn AgentConnection>,
21 },
22 Error {
23 error: LoadError,
24 },
25}
26
27impl ConnectionEntry {
28 pub fn wait_for_connection(&self) -> Shared<Task<Result<Rc<dyn AgentConnection>, LoadError>>> {
29 match self {
30 ConnectionEntry::Connecting { connect_task } => connect_task.clone(),
31 ConnectionEntry::Connected { connection } => {
32 Task::ready(Ok(connection.clone())).shared()
33 }
34 ConnectionEntry::Error { error } => Task::ready(Err(error.clone())).shared(),
35 }
36 }
37}
38
39pub enum ConnectionEntryEvent {
40 NewVersionAvailable(SharedString),
41}
42
43impl EventEmitter<ConnectionEntryEvent> for ConnectionEntry {}
44
45pub struct AgentConnectionStore {
46 project: Entity<Project>,
47 entries: HashMap<ExternalAgent, Entity<ConnectionEntry>>,
48 _subscriptions: Vec<Subscription>,
49}
50
51impl AgentConnectionStore {
52 pub fn new(project: Entity<Project>, cx: &mut Context<Self>) -> Self {
53 let agent_server_store = project.read(cx).agent_server_store().clone();
54 let subscription = cx.subscribe(&agent_server_store, Self::handle_agent_servers_updated);
55 Self {
56 project,
57 entries: HashMap::default(),
58 _subscriptions: vec![subscription],
59 }
60 }
61
62 pub fn request_connection(
63 &mut self,
64 key: ExternalAgent,
65 server: Rc<dyn AgentServer>,
66 cx: &mut Context<Self>,
67 ) -> Entity<ConnectionEntry> {
68 self.entries.get(&key).cloned().unwrap_or_else(|| {
69 let (mut new_version_rx, connect_task) = self.start_connection(server.clone(), cx);
70 let connect_task = connect_task.shared();
71
72 let entry = cx.new(|_cx| ConnectionEntry::Connecting {
73 connect_task: connect_task.clone(),
74 });
75
76 self.entries.insert(key.clone(), entry.clone());
77
78 cx.spawn({
79 let key = key.clone();
80 let entry = entry.clone();
81 async move |this, cx| match connect_task.await {
82 Ok(connection) => {
83 entry.update(cx, |entry, cx| {
84 if let ConnectionEntry::Connecting { .. } = entry {
85 *entry = ConnectionEntry::Connected { connection };
86 cx.notify();
87 }
88 });
89 }
90 Err(error) => {
91 entry.update(cx, |entry, cx| {
92 if let ConnectionEntry::Connecting { .. } = entry {
93 *entry = ConnectionEntry::Error { error };
94 cx.notify();
95 }
96 });
97 this.update(cx, |this, _cx| this.entries.remove(&key)).ok();
98 }
99 }
100 })
101 .detach();
102
103 cx.spawn({
104 let entry = entry.clone();
105 async move |this, cx| {
106 while let Ok(version) = new_version_rx.recv().await {
107 if let Some(version) = version {
108 entry.update(cx, |_entry, cx| {
109 cx.emit(ConnectionEntryEvent::NewVersionAvailable(
110 version.clone().into(),
111 ));
112 });
113 this.update(cx, |this, _cx| this.entries.remove(&key)).ok();
114 }
115 }
116 }
117 })
118 .detach();
119
120 entry
121 })
122 }
123
124 fn handle_agent_servers_updated(
125 &mut self,
126 store: Entity<AgentServerStore>,
127 _: &AgentServersUpdated,
128 cx: &mut Context<Self>,
129 ) {
130 let store = store.read(cx);
131 self.entries.retain(|key, _| match key {
132 ExternalAgent::NativeAgent => true,
133 ExternalAgent::Custom { name } => store
134 .external_agents
135 .contains_key(&ExternalAgentServerName(name.clone())),
136 });
137 cx.notify();
138 }
139
140 fn start_connection(
141 &self,
142 server: Rc<dyn AgentServer>,
143 cx: &mut Context<Self>,
144 ) -> (
145 Receiver<Option<String>>,
146 Task<Result<Rc<dyn AgentConnection>, LoadError>>,
147 ) {
148 let (new_version_tx, new_version_rx) = watch::channel::<Option<String>>(None);
149
150 let agent_server_store = self.project.read(cx).agent_server_store().clone();
151 let delegate = AgentServerDelegate::new(agent_server_store, Some(new_version_tx));
152
153 let connect_task = server.connect(delegate, cx);
154 let connect_task = cx.spawn(async move |_this, _cx| match connect_task.await {
155 Ok(connection) => Ok(connection),
156 Err(err) => match err.downcast::<LoadError>() {
157 Ok(load_error) => Err(load_error),
158 Err(err) => Err(LoadError::Other(SharedString::from(err.to_string()))),
159 },
160 });
161 (new_version_rx, connect_task)
162 }
163}