1use std::rc::Rc;
2
3use acp_thread::{AgentConnection, LoadError};
4use agent_servers::{AgentServer, AgentServerDelegate};
5use anyhow::Result;
6use collections::HashMap;
7use futures::{FutureExt, future::Shared};
8use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Subscription, Task};
9use project::{AgentServerStore, AgentServersUpdated, Project};
10use watch::Receiver;
11
12use crate::{Agent, ThreadHistory};
13use project::ExternalAgentServerName;
14
15pub enum AgentConnectionEntry {
16 Connecting {
17 connect_task: Shared<Task<Result<AgentConnectedState, LoadError>>>,
18 },
19 Connected(AgentConnectedState),
20 Error {
21 error: LoadError,
22 },
23}
24
25#[derive(Clone)]
26pub struct AgentConnectedState {
27 pub connection: Rc<dyn AgentConnection>,
28 pub history: Entity<ThreadHistory>,
29}
30
31impl AgentConnectionEntry {
32 pub fn wait_for_connection(&self) -> Shared<Task<Result<AgentConnectedState, LoadError>>> {
33 match self {
34 AgentConnectionEntry::Connecting { connect_task } => connect_task.clone(),
35 AgentConnectionEntry::Connected(state) => Task::ready(Ok(state.clone())).shared(),
36 AgentConnectionEntry::Error { error } => Task::ready(Err(error.clone())).shared(),
37 }
38 }
39
40 pub fn history(&self) -> Option<&Entity<ThreadHistory>> {
41 match self {
42 AgentConnectionEntry::Connected(state) => Some(&state.history),
43 _ => None,
44 }
45 }
46}
47
48pub enum AgentConnectionEntryEvent {
49 NewVersionAvailable(SharedString),
50}
51
52impl EventEmitter<AgentConnectionEntryEvent> for AgentConnectionEntry {}
53
54pub struct AgentConnectionStore {
55 project: Entity<Project>,
56 entries: HashMap<Agent, Entity<AgentConnectionEntry>>,
57 _subscriptions: Vec<Subscription>,
58}
59
60impl AgentConnectionStore {
61 pub fn new(project: Entity<Project>, cx: &mut Context<Self>) -> Self {
62 let agent_server_store = project.read(cx).agent_server_store().clone();
63 let subscription = cx.subscribe(&agent_server_store, Self::handle_agent_servers_updated);
64 Self {
65 project,
66 entries: HashMap::default(),
67 _subscriptions: vec![subscription],
68 }
69 }
70
71 pub fn entry(&self, key: &Agent) -> Option<&Entity<AgentConnectionEntry>> {
72 self.entries.get(key)
73 }
74
75 pub fn request_connection(
76 &mut self,
77 key: Agent,
78 server: Rc<dyn AgentServer>,
79 cx: &mut Context<Self>,
80 ) -> Entity<AgentConnectionEntry> {
81 self.entries.get(&key).cloned().unwrap_or_else(|| {
82 let (mut new_version_rx, connect_task) = self.start_connection(server.clone(), cx);
83 let connect_task = connect_task.shared();
84
85 let entry = cx.new(|_cx| AgentConnectionEntry::Connecting {
86 connect_task: connect_task.clone(),
87 });
88
89 self.entries.insert(key.clone(), entry.clone());
90
91 cx.spawn({
92 let key = key.clone();
93 let entry = entry.clone();
94 async move |this, cx| match connect_task.await {
95 Ok(connected_state) => {
96 entry.update(cx, |entry, cx| {
97 if let AgentConnectionEntry::Connecting { .. } = entry {
98 *entry = AgentConnectionEntry::Connected(connected_state);
99 cx.notify();
100 }
101 });
102 }
103 Err(error) => {
104 entry.update(cx, |entry, cx| {
105 if let AgentConnectionEntry::Connecting { .. } = entry {
106 *entry = AgentConnectionEntry::Error { error };
107 cx.notify();
108 }
109 });
110 this.update(cx, |this, _cx| this.entries.remove(&key)).ok();
111 }
112 }
113 })
114 .detach();
115
116 cx.spawn({
117 let entry = entry.clone();
118 async move |this, cx| {
119 while let Ok(version) = new_version_rx.recv().await {
120 if let Some(version) = version {
121 entry.update(cx, |_entry, cx| {
122 cx.emit(AgentConnectionEntryEvent::NewVersionAvailable(
123 version.clone().into(),
124 ));
125 });
126 this.update(cx, |this, _cx| this.entries.remove(&key)).ok();
127 }
128 }
129 }
130 })
131 .detach();
132
133 entry
134 })
135 }
136
137 fn handle_agent_servers_updated(
138 &mut self,
139 store: Entity<AgentServerStore>,
140 _: &AgentServersUpdated,
141 cx: &mut Context<Self>,
142 ) {
143 let store = store.read(cx);
144 self.entries.retain(|key, _| match key {
145 Agent::NativeAgent => true,
146 Agent::Custom { name } => store
147 .external_agents
148 .contains_key(&ExternalAgentServerName(name.clone())),
149 });
150 cx.notify();
151 }
152
153 fn start_connection(
154 &self,
155 server: Rc<dyn AgentServer>,
156 cx: &mut Context<Self>,
157 ) -> (
158 Receiver<Option<String>>,
159 Task<Result<AgentConnectedState, LoadError>>,
160 ) {
161 let (new_version_tx, new_version_rx) = watch::channel::<Option<String>>(None);
162
163 let agent_server_store = self.project.read(cx).agent_server_store().clone();
164 let delegate = AgentServerDelegate::new(agent_server_store, Some(new_version_tx));
165
166 let connect_task = server.connect(delegate, cx);
167 let connect_task = cx.spawn(async move |_this, cx| match connect_task.await {
168 Ok(connection) => cx.update(|cx| {
169 let history = cx.new(|cx| ThreadHistory::new(connection.session_list(cx), cx));
170 Ok(AgentConnectedState {
171 connection,
172 history,
173 })
174 }),
175 Err(err) => match err.downcast::<LoadError>() {
176 Ok(load_error) => Err(load_error),
177 Err(err) => Err(LoadError::Other(SharedString::from(err.to_string()))),
178 },
179 });
180 (new_version_rx, connect_task)
181 }
182}