1use crate::{
2 example::Example,
3 headless::EpAppState,
4 load_project::run_load_project,
5 progress::{ExampleProgress, InfoStyle, Step, StepProgress},
6};
7use anyhow::Context as _;
8use collections::HashSet;
9use edit_prediction::{DebugEvent, EditPredictionStore};
10use futures::{FutureExt as _, StreamExt as _, channel::mpsc};
11use gpui::{AsyncApp, Entity};
12use language::Buffer;
13use project::Project;
14use std::sync::Arc;
15use std::time::Duration;
16
17pub async fn run_context_retrieval(
18 example: &mut Example,
19 app_state: Arc<EpAppState>,
20 example_progress: &ExampleProgress,
21 mut cx: AsyncApp,
22) -> anyhow::Result<()> {
23 if example
24 .prompt_inputs
25 .as_ref()
26 .is_some_and(|inputs| inputs.related_files.is_some())
27 {
28 return Ok(());
29 }
30
31 run_load_project(example, app_state.clone(), example_progress, cx.clone()).await?;
32
33 let step_progress: Arc<StepProgress> = example_progress.start(Step::Context).into();
34
35 let state = example.state.as_ref().unwrap();
36 let project = state.project.clone();
37
38 let _lsp_handle = project.update(&mut cx, |project, cx| {
39 project.register_buffer_with_language_servers(&state.buffer, cx)
40 });
41 wait_for_language_servers_to_start(&project, &state.buffer, &step_progress, &mut cx).await?;
42
43 let ep_store = cx
44 .update(|cx| EditPredictionStore::try_global(cx))
45 .context("EditPredictionStore not initialized")?;
46
47 let mut events = ep_store.update(&mut cx, |store, cx| {
48 store.register_buffer(&state.buffer, &project, cx);
49 store.set_use_context(true);
50 store.refresh_context(&project, &state.buffer, state.cursor_position, cx);
51 store.debug_info(&project, cx)
52 });
53
54 while let Some(event) = events.next().await {
55 match event {
56 DebugEvent::ContextRetrievalFinished(_) => {
57 break;
58 }
59 _ => {}
60 }
61 }
62
63 let context_files =
64 ep_store.update(&mut cx, |store, cx| store.context_for_project(&project, cx));
65
66 let excerpt_count: usize = context_files.iter().map(|f| f.excerpts.len()).sum();
67 step_progress.set_info(format!("{} excerpts", excerpt_count), InfoStyle::Normal);
68
69 if let Some(prompt_inputs) = example.prompt_inputs.as_mut() {
70 prompt_inputs.related_files = Some(context_files);
71 }
72 Ok(())
73}
74
75async fn wait_for_language_servers_to_start(
76 project: &Entity<Project>,
77 buffer: &Entity<Buffer>,
78 step_progress: &Arc<StepProgress>,
79 cx: &mut AsyncApp,
80) -> anyhow::Result<()> {
81 let lsp_store = project.read_with(cx, |project, _| project.lsp_store());
82
83 let (language_server_ids, mut starting_language_server_ids) =
84 buffer.update(cx, |buffer, cx| {
85 lsp_store.update(cx, |lsp_store, cx| {
86 let ids = lsp_store.language_servers_for_local_buffer(buffer, cx);
87 let starting_ids = ids
88 .iter()
89 .copied()
90 .filter(|id| !lsp_store.language_server_statuses.contains_key(&id))
91 .collect::<HashSet<_>>();
92 (ids, starting_ids)
93 })
94 });
95
96 step_progress.set_substatus(format!("waiting for {} LSPs", language_server_ids.len()));
97
98 let timeout_duration = if starting_language_server_ids.is_empty() {
99 Duration::from_secs(30)
100 } else {
101 Duration::from_secs(60 * 5)
102 };
103
104 let timeout = cx.background_executor().timer(timeout_duration).shared();
105
106 let (mut tx, mut rx) = mpsc::channel(language_server_ids.len());
107 let added_subscription = cx.subscribe(project, {
108 let step_progress = step_progress.clone();
109 move |_, event, _| match event {
110 project::Event::LanguageServerAdded(language_server_id, name, _) => {
111 step_progress.set_substatus(format!("LSP started: {}", name));
112 tx.try_send(*language_server_id).ok();
113 }
114 _ => {}
115 }
116 });
117
118 while !starting_language_server_ids.is_empty() {
119 futures::select! {
120 language_server_id = rx.next() => {
121 if let Some(id) = language_server_id {
122 starting_language_server_ids.remove(&id);
123 }
124 },
125 _ = timeout.clone().fuse() => {
126 return Err(anyhow::anyhow!("LSP wait timed out after {} minutes", timeout_duration.as_secs() / 60));
127 }
128 }
129 }
130
131 drop(added_subscription);
132
133 let (mut tx, mut rx) = mpsc::channel(language_server_ids.len());
134 let subscriptions = [
135 cx.subscribe(&lsp_store, {
136 let step_progress = step_progress.clone();
137 move |_, event, _| {
138 if let project::LspStoreEvent::LanguageServerUpdate {
139 message:
140 client::proto::update_language_server::Variant::WorkProgress(
141 client::proto::LspWorkProgress {
142 message: Some(message),
143 ..
144 },
145 ),
146 ..
147 } = event
148 {
149 step_progress.set_substatus(message.clone());
150 }
151 }
152 }),
153 cx.subscribe(project, {
154 let step_progress = step_progress.clone();
155 let lsp_store = lsp_store.clone();
156 move |_, event, cx| match event {
157 project::Event::DiskBasedDiagnosticsFinished { language_server_id } => {
158 let lsp_store = lsp_store.read(cx);
159 let name = lsp_store
160 .language_server_adapter_for_id(*language_server_id)
161 .unwrap()
162 .name();
163 step_progress.set_substatus(format!("LSP idle: {}", name));
164 tx.try_send(*language_server_id).ok();
165 }
166 _ => {}
167 }
168 }),
169 ];
170
171 project
172 .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
173 .await?;
174
175 let mut pending_language_server_ids = lsp_store.read_with(cx, |lsp_store, _| {
176 language_server_ids
177 .iter()
178 .copied()
179 .filter(|id| {
180 lsp_store
181 .language_server_statuses
182 .get(id)
183 .is_some_and(|status| status.has_pending_diagnostic_updates)
184 })
185 .collect::<HashSet<_>>()
186 });
187 while !pending_language_server_ids.is_empty() {
188 futures::select! {
189 language_server_id = rx.next() => {
190 if let Some(id) = language_server_id {
191 pending_language_server_ids.remove(&id);
192 }
193 },
194 _ = timeout.clone().fuse() => {
195 return Err(anyhow::anyhow!("LSP wait timed out after {} minutes", timeout_duration.as_secs() / 60));
196 }
197 }
198 }
199
200 drop(subscriptions);
201 step_progress.clear_substatus();
202 Ok(())
203}