1use crate::{
2 example::Example,
3 headless::EpAppState,
4 load_project::run_load_project,
5 progress::{ExampleProgress, InfoStyle, Step, StepProgress},
6};
7use anyhow::Context as _;
8use collections::HashSet;
9use edit_prediction::{DebugEvent, EditPredictionStore};
10use futures::{FutureExt as _, StreamExt as _, channel::mpsc};
11use gpui::{AsyncApp, Entity};
12use language::Buffer;
13use project::Project;
14use std::sync::Arc;
15use std::time::Duration;
16
17pub async fn run_context_retrieval(
18 example: &mut Example,
19 app_state: Arc<EpAppState>,
20 example_progress: &ExampleProgress,
21 mut cx: AsyncApp,
22) -> anyhow::Result<()> {
23 if example
24 .prompt_inputs
25 .as_ref()
26 .is_some_and(|inputs| inputs.related_files.is_some())
27 || example.spec.repository_url.is_empty()
28 {
29 return Ok(());
30 }
31
32 run_load_project(example, app_state.clone(), example_progress, cx.clone()).await?;
33
34 let step_progress: Arc<StepProgress> = example_progress.start(Step::Context).into();
35
36 let state = example.state.as_ref().unwrap();
37 let project = state.project.clone();
38
39 let _lsp_handle = project.update(&mut cx, |project, cx| {
40 project.register_buffer_with_language_servers(&state.buffer, cx)
41 });
42 wait_for_language_servers_to_start(&project, &state.buffer, &step_progress, &mut cx).await?;
43
44 let ep_store = cx
45 .update(|cx| EditPredictionStore::try_global(cx))
46 .context("EditPredictionStore not initialized")?;
47
48 let mut events = ep_store.update(&mut cx, |store, cx| {
49 store.register_buffer(&state.buffer, &project, cx);
50 store.refresh_context(&project, &state.buffer, state.cursor_position, cx);
51 store.debug_info(&project, cx)
52 });
53
54 while let Some(event) = events.next().await {
55 match event {
56 DebugEvent::ContextRetrievalFinished(_) => {
57 break;
58 }
59 _ => {}
60 }
61 }
62
63 let context_files =
64 ep_store.update(&mut cx, |store, cx| store.context_for_project(&project, cx));
65
66 let excerpt_count: usize = context_files.iter().map(|f| f.excerpts.len()).sum();
67 step_progress.set_info(format!("{} excerpts", excerpt_count), InfoStyle::Normal);
68
69 if let Some(prompt_inputs) = example.prompt_inputs.as_mut() {
70 prompt_inputs.related_files = Some(context_files);
71 }
72 Ok(())
73}
74
75async fn wait_for_language_servers_to_start(
76 project: &Entity<Project>,
77 buffer: &Entity<Buffer>,
78 step_progress: &Arc<StepProgress>,
79 cx: &mut AsyncApp,
80) -> anyhow::Result<()> {
81 let lsp_store = project.read_with(cx, |project, _| project.lsp_store());
82
83 // Determine which servers exist for this buffer, and which are still starting.
84 let mut servers_pending_start = HashSet::default();
85 let mut servers_pending_diagnostics = HashSet::default();
86 buffer.update(cx, |buffer, cx| {
87 lsp_store.update(cx, |lsp_store, cx| {
88 let ids = lsp_store.language_servers_for_local_buffer(buffer, cx);
89 for &id in &ids {
90 match lsp_store.language_server_statuses.get(&id) {
91 None => {
92 servers_pending_start.insert(id);
93 servers_pending_diagnostics.insert(id);
94 }
95 Some(status) if status.has_pending_diagnostic_updates => {
96 servers_pending_diagnostics.insert(id);
97 }
98 Some(_) => {}
99 }
100 }
101 });
102 });
103
104 step_progress.set_substatus(format!(
105 "waiting for {} LSPs",
106 servers_pending_diagnostics.len()
107 ));
108
109 let timeout_duration = if servers_pending_start.is_empty() {
110 Duration::from_secs(30)
111 } else {
112 Duration::from_secs(60 * 5)
113 };
114 let timeout = cx.background_executor().timer(timeout_duration).shared();
115
116 let (mut started_tx, mut started_rx) = mpsc::channel(servers_pending_start.len().max(1));
117 let (mut diag_tx, mut diag_rx) = mpsc::channel(servers_pending_diagnostics.len().max(1));
118 let subscriptions = [cx.subscribe(&lsp_store, {
119 let step_progress = step_progress.clone();
120 move |lsp_store, event, cx| match event {
121 project::LspStoreEvent::LanguageServerAdded(id, name, _) => {
122 step_progress.set_substatus(format!("LSP started: {}", name));
123 started_tx.try_send(*id).ok();
124 }
125 project::LspStoreEvent::DiskBasedDiagnosticsFinished { language_server_id } => {
126 let name = lsp_store
127 .read(cx)
128 .language_server_adapter_for_id(*language_server_id)
129 .unwrap()
130 .name();
131 step_progress.set_substatus(format!("LSP idle: {}", name));
132 diag_tx.try_send(*language_server_id).ok();
133 }
134 project::LspStoreEvent::LanguageServerUpdate {
135 message:
136 client::proto::update_language_server::Variant::WorkProgress(
137 client::proto::LspWorkProgress {
138 message: Some(message),
139 ..
140 },
141 ),
142 ..
143 } => {
144 step_progress.set_substatus(message.clone());
145 }
146 _ => {}
147 }
148 })];
149
150 // Phase 1: wait for all servers to start.
151 while !servers_pending_start.is_empty() {
152 futures::select! {
153 id = started_rx.next() => {
154 if let Some(id) = id {
155 servers_pending_start.remove(&id);
156 }
157 },
158 _ = timeout.clone().fuse() => {
159 return Err(anyhow::anyhow!("LSP wait timed out after {} minutes", timeout_duration.as_secs() / 60));
160 }
161 }
162 }
163
164 // Save the buffer so the server sees the current content and kicks off diagnostics.
165 project
166 .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
167 .await?;
168
169 // Phase 2: wait for all servers to finish their diagnostic pass.
170 while !servers_pending_diagnostics.is_empty() {
171 futures::select! {
172 id = diag_rx.next() => {
173 if let Some(id) = id {
174 servers_pending_diagnostics.remove(&id);
175 }
176 },
177 _ = timeout.clone().fuse() => {
178 return Err(anyhow::anyhow!("LSP wait timed out after {} minutes", timeout_duration.as_secs() / 60));
179 }
180 }
181 }
182
183 drop(subscriptions);
184 step_progress.clear_substatus();
185 Ok(())
186}