1use crate::{
2 example::Example,
3 headless::EpAppState,
4 load_project::run_load_project,
5 progress::{InfoStyle, Progress, Step, StepProgress},
6};
7use anyhow::Context as _;
8use collections::HashSet;
9use edit_prediction::{DebugEvent, EditPredictionStore};
10use futures::{FutureExt as _, StreamExt as _, channel::mpsc};
11use gpui::{AsyncApp, Entity};
12use language::Buffer;
13use project::Project;
14use std::sync::Arc;
15use std::time::Duration;
16
17pub async fn run_context_retrieval(
18 example: &mut Example,
19 app_state: Arc<EpAppState>,
20 mut cx: AsyncApp,
21) -> anyhow::Result<()> {
22 if example
23 .prompt_inputs
24 .as_ref()
25 .is_some_and(|inputs| inputs.related_files.is_some())
26 {
27 return Ok(());
28 }
29
30 run_load_project(example, app_state.clone(), cx.clone()).await?;
31
32 let step_progress: Arc<StepProgress> = Progress::global()
33 .start(Step::Context, &example.spec.name)
34 .into();
35
36 let state = example.state.as_ref().unwrap();
37 let project = state.project.clone();
38
39 let _lsp_handle = project.update(&mut cx, |project, cx| {
40 project.register_buffer_with_language_servers(&state.buffer, cx)
41 });
42 wait_for_language_servers_to_start(&project, &state.buffer, &step_progress, &mut cx).await?;
43
44 let ep_store = cx
45 .update(|cx| EditPredictionStore::try_global(cx))
46 .context("EditPredictionStore not initialized")?;
47
48 let mut events = ep_store.update(&mut cx, |store, cx| {
49 store.register_buffer(&state.buffer, &project, cx);
50 store.set_use_context(true);
51 store.refresh_context(&project, &state.buffer, state.cursor_position, cx);
52 store.debug_info(&project, cx)
53 });
54
55 while let Some(event) = events.next().await {
56 match event {
57 DebugEvent::ContextRetrievalFinished(_) => {
58 break;
59 }
60 _ => {}
61 }
62 }
63
64 let context_files =
65 ep_store.update(&mut cx, |store, cx| store.context_for_project(&project, cx));
66
67 let excerpt_count: usize = context_files.iter().map(|f| f.excerpts.len()).sum();
68 step_progress.set_info(format!("{} excerpts", excerpt_count), InfoStyle::Normal);
69
70 if let Some(prompt_inputs) = example.prompt_inputs.as_mut() {
71 prompt_inputs.related_files = Some(context_files);
72 }
73 Ok(())
74}
75
76async fn wait_for_language_servers_to_start(
77 project: &Entity<Project>,
78 buffer: &Entity<Buffer>,
79 step_progress: &Arc<StepProgress>,
80 cx: &mut AsyncApp,
81) -> anyhow::Result<()> {
82 let lsp_store = project.read_with(cx, |project, _| project.lsp_store());
83
84 let (language_server_ids, mut starting_language_server_ids) =
85 buffer.update(cx, |buffer, cx| {
86 lsp_store.update(cx, |lsp_store, cx| {
87 let ids = lsp_store.language_servers_for_local_buffer(buffer, cx);
88 let starting_ids = ids
89 .iter()
90 .copied()
91 .filter(|id| !lsp_store.language_server_statuses.contains_key(&id))
92 .collect::<HashSet<_>>();
93 (ids, starting_ids)
94 })
95 });
96
97 step_progress.set_substatus(format!("waiting for {} LSPs", language_server_ids.len()));
98
99 let timeout = cx
100 .background_executor()
101 .timer(Duration::from_secs(60 * 5))
102 .shared();
103
104 let (mut tx, mut rx) = mpsc::channel(language_server_ids.len());
105 let added_subscription = cx.subscribe(project, {
106 let step_progress = step_progress.clone();
107 move |_, event, _| match event {
108 project::Event::LanguageServerAdded(language_server_id, name, _) => {
109 step_progress.set_substatus(format!("LSP started: {}", name));
110 tx.try_send(*language_server_id).ok();
111 }
112 _ => {}
113 }
114 });
115
116 while !starting_language_server_ids.is_empty() {
117 futures::select! {
118 language_server_id = rx.next() => {
119 if let Some(id) = language_server_id {
120 starting_language_server_ids.remove(&id);
121 }
122 },
123 _ = timeout.clone().fuse() => {
124 return Err(anyhow::anyhow!("LSP wait timed out after 5 minutes"));
125 }
126 }
127 }
128
129 drop(added_subscription);
130
131 let (mut tx, mut rx) = mpsc::channel(language_server_ids.len());
132 let subscriptions = [
133 cx.subscribe(&lsp_store, {
134 let step_progress = step_progress.clone();
135 move |_, event, _| {
136 if let project::LspStoreEvent::LanguageServerUpdate {
137 message:
138 client::proto::update_language_server::Variant::WorkProgress(
139 client::proto::LspWorkProgress {
140 message: Some(message),
141 ..
142 },
143 ),
144 ..
145 } = event
146 {
147 step_progress.set_substatus(message.clone());
148 }
149 }
150 }),
151 cx.subscribe(project, {
152 let step_progress = step_progress.clone();
153 let lsp_store = lsp_store.clone();
154 move |_, event, cx| match event {
155 project::Event::DiskBasedDiagnosticsFinished { language_server_id } => {
156 let lsp_store = lsp_store.read(cx);
157 let name = lsp_store
158 .language_server_adapter_for_id(*language_server_id)
159 .unwrap()
160 .name();
161 step_progress.set_substatus(format!("LSP idle: {}", name));
162 tx.try_send(*language_server_id).ok();
163 }
164 _ => {}
165 }
166 }),
167 ];
168
169 project
170 .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
171 .await?;
172
173 let mut pending_language_server_ids = lsp_store.read_with(cx, |lsp_store, _| {
174 language_server_ids
175 .iter()
176 .copied()
177 .filter(|id| {
178 lsp_store
179 .language_server_statuses
180 .get(id)
181 .is_some_and(|status| status.has_pending_diagnostic_updates)
182 })
183 .collect::<HashSet<_>>()
184 });
185 while !pending_language_server_ids.is_empty() {
186 futures::select! {
187 language_server_id = rx.next() => {
188 if let Some(id) = language_server_id {
189 pending_language_server_ids.remove(&id);
190 }
191 },
192 _ = timeout.clone().fuse() => {
193 return Err(anyhow::anyhow!("LSP wait timed out after 5 minutes"));
194 }
195 }
196 }
197
198 drop(subscriptions);
199 step_progress.clear_substatus();
200 Ok(())
201}