1use crate::{
2 example::{Example, ExampleContext},
3 headless::EpAppState,
4 load_project::run_load_project,
5 progress::{InfoStyle, Progress, Step, StepProgress},
6};
7use anyhow::Context as _;
8use collections::HashSet;
9use edit_prediction::{DebugEvent, EditPredictionStore};
10use futures::{FutureExt as _, StreamExt as _, channel::mpsc};
11use gpui::{AsyncApp, Entity};
12use language::Buffer;
13use project::Project;
14use std::sync::Arc;
15use std::time::Duration;
16
17pub async fn run_context_retrieval(
18 example: &mut Example,
19 app_state: Arc<EpAppState>,
20 mut cx: AsyncApp,
21) -> anyhow::Result<()> {
22 if example.context.is_some() {
23 return Ok(());
24 }
25
26 run_load_project(example, app_state.clone(), cx.clone()).await?;
27
28 let step_progress: Arc<StepProgress> = Progress::global()
29 .start(Step::Context, &example.spec.name)
30 .into();
31
32 let state = example.state.as_ref().unwrap();
33 let project = state.project.clone();
34
35 let _lsp_handle = project.update(&mut cx, |project, cx| {
36 project.register_buffer_with_language_servers(&state.buffer, cx)
37 });
38 wait_for_language_servers_to_start(&project, &state.buffer, &step_progress, &mut cx).await?;
39
40 let ep_store = cx
41 .update(|cx| EditPredictionStore::try_global(cx))
42 .context("EditPredictionStore not initialized")?;
43
44 let mut events = ep_store.update(&mut cx, |store, cx| {
45 store.register_buffer(&state.buffer, &project, cx);
46 store.set_use_context(true);
47 store.refresh_context(&project, &state.buffer, state.cursor_position, cx);
48 store.debug_info(&project, cx)
49 });
50
51 while let Some(event) = events.next().await {
52 match event {
53 DebugEvent::ContextRetrievalFinished(_) => {
54 break;
55 }
56 _ => {}
57 }
58 }
59
60 let context_files =
61 ep_store.update(&mut cx, |store, cx| store.context_for_project(&project, cx));
62
63 let excerpt_count: usize = context_files.iter().map(|f| f.excerpts.len()).sum();
64 step_progress.set_info(format!("{} excerpts", excerpt_count), InfoStyle::Normal);
65
66 example.context = Some(ExampleContext {
67 files: context_files,
68 });
69 Ok(())
70}
71
72async fn wait_for_language_servers_to_start(
73 project: &Entity<Project>,
74 buffer: &Entity<Buffer>,
75 step_progress: &Arc<StepProgress>,
76 cx: &mut AsyncApp,
77) -> anyhow::Result<()> {
78 let lsp_store = project.read_with(cx, |project, _| project.lsp_store());
79
80 let (language_server_ids, mut starting_language_server_ids) =
81 buffer.update(cx, |buffer, cx| {
82 lsp_store.update(cx, |lsp_store, cx| {
83 let ids = lsp_store.language_servers_for_local_buffer(buffer, cx);
84 let starting_ids = ids
85 .iter()
86 .copied()
87 .filter(|id| !lsp_store.language_server_statuses.contains_key(&id))
88 .collect::<HashSet<_>>();
89 (ids, starting_ids)
90 })
91 });
92
93 step_progress.set_substatus(format!("waiting for {} LSPs", language_server_ids.len()));
94
95 let timeout = cx
96 .background_executor()
97 .timer(Duration::from_secs(60 * 5))
98 .shared();
99
100 let (mut tx, mut rx) = mpsc::channel(language_server_ids.len());
101 let added_subscription = cx.subscribe(project, {
102 let step_progress = step_progress.clone();
103 move |_, event, _| match event {
104 project::Event::LanguageServerAdded(language_server_id, name, _) => {
105 step_progress.set_substatus(format!("LSP started: {}", name));
106 tx.try_send(*language_server_id).ok();
107 }
108 _ => {}
109 }
110 });
111
112 while !starting_language_server_ids.is_empty() {
113 futures::select! {
114 language_server_id = rx.next() => {
115 if let Some(id) = language_server_id {
116 starting_language_server_ids.remove(&id);
117 }
118 },
119 _ = timeout.clone().fuse() => {
120 return Err(anyhow::anyhow!("LSP wait timed out after 5 minutes"));
121 }
122 }
123 }
124
125 drop(added_subscription);
126
127 let (mut tx, mut rx) = mpsc::channel(language_server_ids.len());
128 let subscriptions = [
129 cx.subscribe(&lsp_store, {
130 let step_progress = step_progress.clone();
131 move |_, event, _| {
132 if let project::LspStoreEvent::LanguageServerUpdate {
133 message:
134 client::proto::update_language_server::Variant::WorkProgress(
135 client::proto::LspWorkProgress {
136 message: Some(message),
137 ..
138 },
139 ),
140 ..
141 } = event
142 {
143 step_progress.set_substatus(message.clone());
144 }
145 }
146 }),
147 cx.subscribe(project, {
148 let step_progress = step_progress.clone();
149 let lsp_store = lsp_store.clone();
150 move |_, event, cx| match event {
151 project::Event::DiskBasedDiagnosticsFinished { language_server_id } => {
152 let lsp_store = lsp_store.read(cx);
153 let name = lsp_store
154 .language_server_adapter_for_id(*language_server_id)
155 .unwrap()
156 .name();
157 step_progress.set_substatus(format!("LSP idle: {}", name));
158 tx.try_send(*language_server_id).ok();
159 }
160 _ => {}
161 }
162 }),
163 ];
164
165 project
166 .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
167 .await?;
168
169 let mut pending_language_server_ids = lsp_store.read_with(cx, |lsp_store, _| {
170 language_server_ids
171 .iter()
172 .copied()
173 .filter(|id| {
174 lsp_store
175 .language_server_statuses
176 .get(id)
177 .is_some_and(|status| status.has_pending_diagnostic_updates)
178 })
179 .collect::<HashSet<_>>()
180 });
181 while !pending_language_server_ids.is_empty() {
182 futures::select! {
183 language_server_id = rx.next() => {
184 if let Some(id) = language_server_id {
185 pending_language_server_ids.remove(&id);
186 }
187 },
188 _ = timeout.clone().fuse() => {
189 return Err(anyhow::anyhow!("LSP wait timed out after 5 minutes"));
190 }
191 }
192 }
193
194 drop(subscriptions);
195 step_progress.clear_substatus();
196 Ok(())
197}