1use std::rc::Rc;
2
3use acp_thread::{AgentConnection, LoadError};
4use agent_servers::{AgentServer, AgentServerDelegate};
5use anyhow::Result;
6use collections::HashMap;
7use futures::{FutureExt, future::Shared};
8use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Subscription, Task};
9use project::{AgentServerStore, AgentServersUpdated, Project};
10use watch::Receiver;
11
12use crate::{Agent, ThreadHistory};
13
14pub enum AgentConnectionEntry {
15 Connecting {
16 connect_task: Shared<Task<Result<AgentConnectedState, LoadError>>>,
17 },
18 Connected(AgentConnectedState),
19 Error {
20 error: LoadError,
21 },
22}
23
24#[derive(Clone)]
25pub struct AgentConnectedState {
26 pub connection: Rc<dyn AgentConnection>,
27 pub history: Option<Entity<ThreadHistory>>,
28}
29
30impl AgentConnectionEntry {
31 pub fn wait_for_connection(&self) -> Shared<Task<Result<AgentConnectedState, LoadError>>> {
32 match self {
33 AgentConnectionEntry::Connecting { connect_task } => connect_task.clone(),
34 AgentConnectionEntry::Connected(state) => Task::ready(Ok(state.clone())).shared(),
35 AgentConnectionEntry::Error { error } => Task::ready(Err(error.clone())).shared(),
36 }
37 }
38
39 pub fn history(&self) -> Option<&Entity<ThreadHistory>> {
40 match self {
41 AgentConnectionEntry::Connected(state) => state.history.as_ref(),
42 _ => None,
43 }
44 }
45}
46
47pub enum AgentConnectionEntryEvent {
48 NewVersionAvailable(SharedString),
49}
50
51impl EventEmitter<AgentConnectionEntryEvent> for AgentConnectionEntry {}
52
53pub struct AgentConnectionStore {
54 project: Entity<Project>,
55 entries: HashMap<Agent, Entity<AgentConnectionEntry>>,
56 _subscriptions: Vec<Subscription>,
57}
58
59impl AgentConnectionStore {
60 pub fn new(project: Entity<Project>, cx: &mut Context<Self>) -> Self {
61 let agent_server_store = project.read(cx).agent_server_store().clone();
62 let subscription = cx.subscribe(&agent_server_store, Self::handle_agent_servers_updated);
63 Self {
64 project,
65 entries: HashMap::default(),
66 _subscriptions: vec![subscription],
67 }
68 }
69
70 pub fn entry(&self, key: &Agent) -> Option<&Entity<AgentConnectionEntry>> {
71 self.entries.get(key)
72 }
73
74 pub fn request_connection(
75 &mut self,
76 key: Agent,
77 server: Rc<dyn AgentServer>,
78 cx: &mut Context<Self>,
79 ) -> Entity<AgentConnectionEntry> {
80 self.entries.get(&key).cloned().unwrap_or_else(|| {
81 let (mut new_version_rx, connect_task) = self.start_connection(server.clone(), cx);
82 let connect_task = connect_task.shared();
83
84 let entry = cx.new(|_cx| AgentConnectionEntry::Connecting {
85 connect_task: connect_task.clone(),
86 });
87
88 self.entries.insert(key.clone(), entry.clone());
89
90 cx.spawn({
91 let key = key.clone();
92 let entry = entry.clone();
93 async move |this, cx| match connect_task.await {
94 Ok(connected_state) => {
95 entry.update(cx, |entry, cx| {
96 if let AgentConnectionEntry::Connecting { .. } = entry {
97 *entry = AgentConnectionEntry::Connected(connected_state);
98 cx.notify();
99 }
100 });
101 }
102 Err(error) => {
103 entry.update(cx, |entry, cx| {
104 if let AgentConnectionEntry::Connecting { .. } = entry {
105 *entry = AgentConnectionEntry::Error { error };
106 cx.notify();
107 }
108 });
109 this.update(cx, |this, _cx| this.entries.remove(&key)).ok();
110 }
111 }
112 })
113 .detach();
114
115 cx.spawn({
116 let entry = entry.clone();
117 async move |this, cx| {
118 while let Ok(version) = new_version_rx.recv().await {
119 if let Some(version) = version {
120 entry.update(cx, |_entry, cx| {
121 cx.emit(AgentConnectionEntryEvent::NewVersionAvailable(
122 version.clone().into(),
123 ));
124 });
125 this.update(cx, |this, _cx| this.entries.remove(&key)).ok();
126 }
127 }
128 }
129 })
130 .detach();
131
132 entry
133 })
134 }
135
136 fn handle_agent_servers_updated(
137 &mut self,
138 store: Entity<AgentServerStore>,
139 _: &AgentServersUpdated,
140 cx: &mut Context<Self>,
141 ) {
142 let store = store.read(cx);
143 self.entries.retain(|key, _| match key {
144 Agent::NativeAgent => true,
145 Agent::Custom { id } => store.external_agents.contains_key(id),
146 });
147 cx.notify();
148 }
149
150 fn start_connection(
151 &self,
152 server: Rc<dyn AgentServer>,
153 cx: &mut Context<Self>,
154 ) -> (
155 Receiver<Option<String>>,
156 Task<Result<AgentConnectedState, LoadError>>,
157 ) {
158 let (new_version_tx, new_version_rx) = watch::channel::<Option<String>>(None);
159
160 let agent_server_store = self.project.read(cx).agent_server_store().clone();
161 let delegate = AgentServerDelegate::new(agent_server_store, Some(new_version_tx));
162
163 let connect_task = server.connect(delegate, self.project.clone(), cx);
164 let connect_task = cx.spawn(async move |_this, cx| match connect_task.await {
165 Ok(connection) => cx.update(|cx| {
166 let history = connection
167 .session_list(cx)
168 .map(|session_list| cx.new(|cx| ThreadHistory::new(session_list, cx)));
169 Ok(AgentConnectedState {
170 connection,
171 history,
172 })
173 }),
174 Err(err) => match err.downcast::<LoadError>() {
175 Ok(load_error) => Err(load_error),
176 Err(err) => Err(LoadError::Other(SharedString::from(err.to_string()))),
177 },
178 });
179 (new_version_rx, connect_task)
180 }
181}