1use crate::{
2 example::Example,
3 headless::EpAppState,
4 load_project::run_load_project,
5 progress::{ExampleProgress, InfoStyle, Step, StepProgress},
6};
7use anyhow::Context as _;
8use collections::HashSet;
9use edit_prediction::{DebugEvent, EditPredictionStore};
10use futures::{FutureExt as _, StreamExt as _, channel::mpsc};
11use gpui::{AsyncApp, Entity};
12use language::Buffer;
13use project::Project;
14use std::sync::Arc;
15use std::time::Duration;
16
17pub async fn run_context_retrieval(
18 example: &mut Example,
19 app_state: Arc<EpAppState>,
20 example_progress: &ExampleProgress,
21 mut cx: AsyncApp,
22) -> anyhow::Result<()> {
23 if example.prompt_inputs.is_some() {
24 if example.spec.repository_url.is_empty() {
25 return Ok(());
26 }
27
28 if example
29 .prompt_inputs
30 .as_ref()
31 .is_some_and(|inputs| !inputs.related_files.is_empty())
32 {
33 return Ok(());
34 }
35 }
36
37 run_load_project(example, app_state.clone(), example_progress, cx.clone()).await?;
38
39 let step_progress: Arc<StepProgress> = example_progress.start(Step::Context).into();
40
41 let state = example.state.as_ref().unwrap();
42 let project = state.project.clone();
43
44 let _lsp_handle = project.update(&mut cx, |project, cx| {
45 project.register_buffer_with_language_servers(&state.buffer, cx)
46 });
47 wait_for_language_servers_to_start(&project, &state.buffer, &step_progress, &mut cx).await?;
48
49 let ep_store = cx
50 .update(|cx| EditPredictionStore::try_global(cx))
51 .context("EditPredictionStore not initialized")?;
52
53 let mut events = ep_store.update(&mut cx, |store, cx| {
54 store.register_buffer(&state.buffer, &project, cx);
55 store.refresh_context(&project, &state.buffer, state.cursor_position, cx);
56 store.debug_info(&project, cx)
57 });
58
59 while let Some(event) = events.next().await {
60 match event {
61 DebugEvent::ContextRetrievalFinished(_) => {
62 break;
63 }
64 _ => {}
65 }
66 }
67
68 let context_files =
69 ep_store.update(&mut cx, |store, cx| store.context_for_project(&project, cx));
70
71 let excerpt_count: usize = context_files.iter().map(|f| f.excerpts.len()).sum();
72 step_progress.set_info(format!("{} excerpts", excerpt_count), InfoStyle::Normal);
73
74 if let Some(prompt_inputs) = example.prompt_inputs.as_mut() {
75 prompt_inputs.related_files = context_files;
76 }
77 Ok(())
78}
79
80async fn wait_for_language_servers_to_start(
81 project: &Entity<Project>,
82 buffer: &Entity<Buffer>,
83 step_progress: &Arc<StepProgress>,
84 cx: &mut AsyncApp,
85) -> anyhow::Result<()> {
86 let lsp_store = project.read_with(cx, |project, _| project.lsp_store());
87
88 // Determine which servers exist for this buffer, and which are still starting.
89 let mut servers_pending_start = HashSet::default();
90 let mut servers_pending_diagnostics = HashSet::default();
91 buffer.update(cx, |buffer, cx| {
92 lsp_store.update(cx, |lsp_store, cx| {
93 let ids = lsp_store.language_servers_for_local_buffer(buffer, cx);
94 for &id in &ids {
95 match lsp_store.language_server_statuses.get(&id) {
96 None => {
97 servers_pending_start.insert(id);
98 servers_pending_diagnostics.insert(id);
99 }
100 Some(status) if status.has_pending_diagnostic_updates => {
101 servers_pending_diagnostics.insert(id);
102 }
103 Some(_) => {}
104 }
105 }
106 });
107 });
108
109 step_progress.set_substatus(format!(
110 "waiting for {} LSPs",
111 servers_pending_diagnostics.len()
112 ));
113
114 let timeout_duration = if servers_pending_start.is_empty() {
115 Duration::from_secs(30)
116 } else {
117 Duration::from_secs(60 * 5)
118 };
119 let timeout = cx.background_executor().timer(timeout_duration).shared();
120
121 let (mut started_tx, mut started_rx) = mpsc::channel(servers_pending_start.len().max(1));
122 let (mut diag_tx, mut diag_rx) = mpsc::channel(servers_pending_diagnostics.len().max(1));
123 let subscriptions = [cx.subscribe(&lsp_store, {
124 let step_progress = step_progress.clone();
125 move |lsp_store, event, cx| match event {
126 project::LspStoreEvent::LanguageServerAdded(id, name, _) => {
127 step_progress.set_substatus(format!("LSP started: {}", name));
128 started_tx.try_send(*id).ok();
129 }
130 project::LspStoreEvent::DiskBasedDiagnosticsFinished { language_server_id } => {
131 let name = lsp_store
132 .read(cx)
133 .language_server_adapter_for_id(*language_server_id)
134 .unwrap()
135 .name();
136 step_progress.set_substatus(format!("LSP idle: {}", name));
137 diag_tx.try_send(*language_server_id).ok();
138 }
139 project::LspStoreEvent::LanguageServerUpdate {
140 message:
141 client::proto::update_language_server::Variant::WorkProgress(
142 client::proto::LspWorkProgress {
143 message: Some(message),
144 ..
145 },
146 ),
147 ..
148 } => {
149 step_progress.set_substatus(message.clone());
150 }
151 _ => {}
152 }
153 })];
154
155 // Phase 1: wait for all servers to start.
156 while !servers_pending_start.is_empty() {
157 futures::select! {
158 id = started_rx.next() => {
159 if let Some(id) = id {
160 servers_pending_start.remove(&id);
161 }
162 },
163 _ = timeout.clone().fuse() => {
164 return Err(anyhow::anyhow!("LSP wait timed out after {} minutes", timeout_duration.as_secs() / 60));
165 }
166 }
167 }
168
169 // Save the buffer so the server sees the current content and kicks off diagnostics.
170 project
171 .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
172 .await?;
173
174 // Phase 2: wait for all servers to finish their diagnostic pass.
175 while !servers_pending_diagnostics.is_empty() {
176 futures::select! {
177 id = diag_rx.next() => {
178 if let Some(id) = id {
179 servers_pending_diagnostics.remove(&id);
180 }
181 },
182 _ = timeout.clone().fuse() => {
183 return Err(anyhow::anyhow!("LSP wait timed out after {} minutes", timeout_duration.as_secs() / 60));
184 }
185 }
186 }
187
188 drop(subscriptions);
189 step_progress.clear_substatus();
190 Ok(())
191}