1use std::rc::Rc;
2
3use acp_thread::{AgentConnection, LoadError};
4use agent::ThreadStore;
5use agent_servers::{AgentServer, AgentServerDelegate};
6use anyhow::Result;
7use collections::HashMap;
8use futures::{FutureExt, future::Shared};
9use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Subscription, Task};
10use project::{AgentServerStore, AgentServersUpdated, Project};
11use watch::Receiver;
12
13use crate::{Agent, ThreadHistory};
14use project::ExternalAgentServerName;
15
16pub enum AgentConnectionEntry {
17 Connecting {
18 connect_task: Shared<Task<Result<AgentConnectedState, LoadError>>>,
19 },
20 Connected(AgentConnectedState),
21 Error {
22 error: LoadError,
23 },
24}
25
26#[derive(Clone)]
27pub struct AgentConnectedState {
28 pub connection: Rc<dyn AgentConnection>,
29 pub history: Entity<ThreadHistory>,
30}
31
32impl AgentConnectionEntry {
33 pub fn wait_for_connection(&self) -> Shared<Task<Result<AgentConnectedState, LoadError>>> {
34 match self {
35 AgentConnectionEntry::Connecting { connect_task } => connect_task.clone(),
36 AgentConnectionEntry::Connected(state) => Task::ready(Ok(state.clone())).shared(),
37 AgentConnectionEntry::Error { error } => Task::ready(Err(error.clone())).shared(),
38 }
39 }
40
41 pub fn history(&self) -> Option<&Entity<ThreadHistory>> {
42 match self {
43 AgentConnectionEntry::Connected(state) => Some(&state.history),
44 _ => None,
45 }
46 }
47}
48
49pub enum AgentConnectionEntryEvent {
50 NewVersionAvailable(SharedString),
51}
52
53impl EventEmitter<AgentConnectionEntryEvent> for AgentConnectionEntry {}
54
55pub struct AgentConnectionStore {
56 project: Entity<Project>,
57 thread_store: Entity<ThreadStore>,
58 entries: HashMap<Agent, Entity<AgentConnectionEntry>>,
59 _subscriptions: Vec<Subscription>,
60}
61
62impl AgentConnectionStore {
63 pub fn new(
64 project: Entity<Project>,
65 thread_store: Entity<ThreadStore>,
66 cx: &mut Context<Self>,
67 ) -> Self {
68 let agent_server_store = project.read(cx).agent_server_store().clone();
69 let subscription = cx.subscribe(&agent_server_store, Self::handle_agent_servers_updated);
70 Self {
71 project,
72 thread_store,
73 entries: HashMap::default(),
74 _subscriptions: vec![subscription],
75 }
76 }
77
78 pub fn entry(&self, key: &Agent) -> Option<&Entity<AgentConnectionEntry>> {
79 self.entries.get(key)
80 }
81
82 pub fn request_connection(
83 &mut self,
84 agent: Agent,
85 cx: &mut Context<Self>,
86 ) -> Entity<AgentConnectionEntry> {
87 self.entries.get(&agent).cloned().unwrap_or_else(|| {
88 let server = agent.server(
89 self.project.read(cx).fs().clone(),
90 self.thread_store.clone(),
91 );
92
93 let (mut new_version_rx, connect_task) = self.start_connection(server.clone(), cx);
94 let connect_task = connect_task.shared();
95
96 let entry = cx.new(|_cx| AgentConnectionEntry::Connecting {
97 connect_task: connect_task.clone(),
98 });
99
100 self.entries.insert(agent.clone(), entry.clone());
101
102 cx.spawn({
103 let key = agent.clone();
104 let entry = entry.clone();
105 async move |this, cx| match connect_task.await {
106 Ok(connected_state) => {
107 entry.update(cx, |entry, cx| {
108 if let AgentConnectionEntry::Connecting { .. } = entry {
109 *entry = AgentConnectionEntry::Connected(connected_state);
110 cx.notify();
111 }
112 });
113 }
114 Err(error) => {
115 entry.update(cx, |entry, cx| {
116 if let AgentConnectionEntry::Connecting { .. } = entry {
117 *entry = AgentConnectionEntry::Error { error };
118 cx.notify();
119 }
120 });
121 this.update(cx, |this, _cx| this.entries.remove(&key)).ok();
122 }
123 }
124 })
125 .detach();
126
127 cx.spawn({
128 let entry = entry.clone();
129 async move |this, cx| {
130 while let Ok(version) = new_version_rx.recv().await {
131 if let Some(version) = version {
132 entry.update(cx, |_entry, cx| {
133 cx.emit(AgentConnectionEntryEvent::NewVersionAvailable(
134 version.clone().into(),
135 ));
136 });
137 this.update(cx, |this, _cx| this.entries.remove(&agent))
138 .ok();
139 }
140 }
141 }
142 })
143 .detach();
144
145 entry
146 })
147 }
148
149 fn handle_agent_servers_updated(
150 &mut self,
151 store: Entity<AgentServerStore>,
152 _: &AgentServersUpdated,
153 cx: &mut Context<Self>,
154 ) {
155 let store = store.read(cx);
156 self.entries.retain(|key, _| match key {
157 Agent::NativeAgent => true,
158 Agent::Custom { name } => store
159 .external_agents
160 .contains_key(&ExternalAgentServerName(name.clone())),
161 });
162 cx.notify();
163 }
164
165 fn start_connection(
166 &self,
167 server: Rc<dyn AgentServer>,
168 cx: &mut Context<Self>,
169 ) -> (
170 Receiver<Option<String>>,
171 Task<Result<AgentConnectedState, LoadError>>,
172 ) {
173 let (new_version_tx, new_version_rx) = watch::channel::<Option<String>>(None);
174
175 let agent_server_store = self.project.read(cx).agent_server_store().clone();
176 let delegate = AgentServerDelegate::new(agent_server_store, Some(new_version_tx));
177
178 let connect_task = server.connect(delegate, cx);
179 let connect_task = cx.spawn(async move |_this, cx| match connect_task.await {
180 Ok(connection) => cx.update(|cx| {
181 let history = cx.new(|cx| ThreadHistory::new(connection.session_list(cx), cx));
182 Ok(AgentConnectedState {
183 connection,
184 history,
185 })
186 }),
187 Err(err) => match err.downcast::<LoadError>() {
188 Ok(load_error) => Err(load_error),
189 Err(err) => Err(LoadError::Other(SharedString::from(err.to_string()))),
190 },
191 });
192 (new_version_rx, connect_task)
193 }
194}