1use crate::{
2 example::Example,
3 headless::EpAppState,
4 load_project::run_load_project,
5 progress::{ExampleProgress, InfoStyle, Step, StepProgress},
6};
7use anyhow::Context as _;
8use collections::HashSet;
9use edit_prediction::{DebugEvent, EditPredictionStore};
10use futures::{FutureExt as _, StreamExt as _, channel::mpsc};
11use gpui::{AsyncApp, Entity};
12use language::Buffer;
13use project::Project;
14use std::sync::Arc;
15use std::time::Duration;
16
17pub async fn run_context_retrieval(
18 example: &mut Example,
19 app_state: Arc<EpAppState>,
20 example_progress: &ExampleProgress,
21 mut cx: AsyncApp,
22) -> anyhow::Result<()> {
23 if example
24 .prompt_inputs
25 .as_ref()
26 .is_some_and(|inputs| inputs.related_files.is_some())
27 {
28 return Ok(());
29 }
30
31 run_load_project(example, app_state.clone(), example_progress, cx.clone()).await?;
32
33 let step_progress: Arc<StepProgress> = example_progress.start(Step::Context).into();
34
35 let state = example.state.as_ref().unwrap();
36 let project = state.project.clone();
37
38 let _lsp_handle = project.update(&mut cx, |project, cx| {
39 project.register_buffer_with_language_servers(&state.buffer, cx)
40 });
41 wait_for_language_servers_to_start(&project, &state.buffer, &step_progress, &mut cx).await?;
42
43 let ep_store = cx
44 .update(|cx| EditPredictionStore::try_global(cx))
45 .context("EditPredictionStore not initialized")?;
46
47 let mut events = ep_store.update(&mut cx, |store, cx| {
48 store.register_buffer(&state.buffer, &project, cx);
49 store.refresh_context(&project, &state.buffer, state.cursor_position, cx);
50 store.debug_info(&project, cx)
51 });
52
53 while let Some(event) = events.next().await {
54 match event {
55 DebugEvent::ContextRetrievalFinished(_) => {
56 break;
57 }
58 _ => {}
59 }
60 }
61
62 let context_files =
63 ep_store.update(&mut cx, |store, cx| store.context_for_project(&project, cx));
64
65 let excerpt_count: usize = context_files.iter().map(|f| f.excerpts.len()).sum();
66 step_progress.set_info(format!("{} excerpts", excerpt_count), InfoStyle::Normal);
67
68 if let Some(prompt_inputs) = example.prompt_inputs.as_mut() {
69 prompt_inputs.related_files = Some(context_files);
70 }
71 Ok(())
72}
73
74async fn wait_for_language_servers_to_start(
75 project: &Entity<Project>,
76 buffer: &Entity<Buffer>,
77 step_progress: &Arc<StepProgress>,
78 cx: &mut AsyncApp,
79) -> anyhow::Result<()> {
80 let lsp_store = project.read_with(cx, |project, _| project.lsp_store());
81
82 // Determine which servers exist for this buffer, and which are still starting.
83 let mut servers_pending_start = HashSet::default();
84 let mut servers_pending_diagnostics = HashSet::default();
85 buffer.update(cx, |buffer, cx| {
86 lsp_store.update(cx, |lsp_store, cx| {
87 let ids = lsp_store.language_servers_for_local_buffer(buffer, cx);
88 for &id in &ids {
89 match lsp_store.language_server_statuses.get(&id) {
90 None => {
91 servers_pending_start.insert(id);
92 servers_pending_diagnostics.insert(id);
93 }
94 Some(status) if status.has_pending_diagnostic_updates => {
95 servers_pending_diagnostics.insert(id);
96 }
97 Some(_) => {}
98 }
99 }
100 });
101 });
102
103 step_progress.set_substatus(format!(
104 "waiting for {} LSPs",
105 servers_pending_diagnostics.len()
106 ));
107
108 let timeout_duration = if servers_pending_start.is_empty() {
109 Duration::from_secs(30)
110 } else {
111 Duration::from_secs(60 * 5)
112 };
113 let timeout = cx.background_executor().timer(timeout_duration).shared();
114
115 let (mut started_tx, mut started_rx) = mpsc::channel(servers_pending_start.len().max(1));
116 let (mut diag_tx, mut diag_rx) = mpsc::channel(servers_pending_diagnostics.len().max(1));
117 let subscriptions = [cx.subscribe(&lsp_store, {
118 let step_progress = step_progress.clone();
119 move |lsp_store, event, cx| match event {
120 project::LspStoreEvent::LanguageServerAdded(id, name, _) => {
121 step_progress.set_substatus(format!("LSP started: {}", name));
122 started_tx.try_send(*id).ok();
123 }
124 project::LspStoreEvent::DiskBasedDiagnosticsFinished { language_server_id } => {
125 let name = lsp_store
126 .read(cx)
127 .language_server_adapter_for_id(*language_server_id)
128 .unwrap()
129 .name();
130 step_progress.set_substatus(format!("LSP idle: {}", name));
131 diag_tx.try_send(*language_server_id).ok();
132 }
133 project::LspStoreEvent::LanguageServerUpdate {
134 message:
135 client::proto::update_language_server::Variant::WorkProgress(
136 client::proto::LspWorkProgress {
137 message: Some(message),
138 ..
139 },
140 ),
141 ..
142 } => {
143 step_progress.set_substatus(message.clone());
144 }
145 _ => {}
146 }
147 })];
148
149 // Phase 1: wait for all servers to start.
150 while !servers_pending_start.is_empty() {
151 futures::select! {
152 id = started_rx.next() => {
153 if let Some(id) = id {
154 servers_pending_start.remove(&id);
155 }
156 },
157 _ = timeout.clone().fuse() => {
158 return Err(anyhow::anyhow!("LSP wait timed out after {} minutes", timeout_duration.as_secs() / 60));
159 }
160 }
161 }
162
163 // Save the buffer so the server sees the current content and kicks off diagnostics.
164 project
165 .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
166 .await?;
167
168 // Phase 2: wait for all servers to finish their diagnostic pass.
169 while !servers_pending_diagnostics.is_empty() {
170 futures::select! {
171 id = diag_rx.next() => {
172 if let Some(id) = id {
173 servers_pending_diagnostics.remove(&id);
174 }
175 },
176 _ = timeout.clone().fuse() => {
177 return Err(anyhow::anyhow!("LSP wait timed out after {} minutes", timeout_duration.as_secs() / 60));
178 }
179 }
180 }
181
182 drop(subscriptions);
183 step_progress.clear_substatus();
184 Ok(())
185}