1use crate::{
2 example::Example,
3 headless::EpAppState,
4 load_project::run_load_project,
5 progress::{ExampleProgress, InfoStyle, Step, StepProgress},
6};
7use anyhow::Context as _;
8use collections::HashSet;
9use edit_prediction::{DebugEvent, EditPredictionStore};
10use futures::{FutureExt as _, StreamExt as _, channel::mpsc};
11use gpui::{AsyncApp, Entity};
12use language::Buffer;
13use project::Project;
14use std::sync::Arc;
15use std::time::Duration;
16
17pub async fn run_context_retrieval(
18 example: &mut Example,
19 app_state: Arc<EpAppState>,
20 example_progress: &ExampleProgress,
21 mut cx: AsyncApp,
22) -> anyhow::Result<()> {
23 if example
24 .prompt_inputs
25 .as_ref()
26 .is_some_and(|inputs| inputs.related_files.is_some())
27 {
28 return Ok(());
29 }
30
31 run_load_project(example, app_state.clone(), example_progress, cx.clone()).await?;
32
33 let step_progress: Arc<StepProgress> = example_progress.start(Step::Context).into();
34
35 let state = example.state.as_ref().unwrap();
36 let project = state.project.clone();
37
38 let _lsp_handle = project.update(&mut cx, |project, cx| {
39 project.register_buffer_with_language_servers(&state.buffer, cx)
40 });
41 wait_for_language_servers_to_start(&project, &state.buffer, &step_progress, &mut cx).await?;
42
43 let ep_store = cx
44 .update(|cx| EditPredictionStore::try_global(cx))
45 .context("EditPredictionStore not initialized")?;
46
47 let mut events = ep_store.update(&mut cx, |store, cx| {
48 store.register_buffer(&state.buffer, &project, cx);
49 store.refresh_context(&project, &state.buffer, state.cursor_position, cx);
50 store.debug_info(&project, cx)
51 });
52
53 while let Some(event) = events.next().await {
54 match event {
55 DebugEvent::ContextRetrievalFinished(_) => {
56 break;
57 }
58 _ => {}
59 }
60 }
61
62 let context_files =
63 ep_store.update(&mut cx, |store, cx| store.context_for_project(&project, cx));
64
65 let excerpt_count: usize = context_files.iter().map(|f| f.excerpts.len()).sum();
66 step_progress.set_info(format!("{} excerpts", excerpt_count), InfoStyle::Normal);
67
68 if let Some(prompt_inputs) = example.prompt_inputs.as_mut() {
69 prompt_inputs.related_files = Some(context_files);
70 }
71 Ok(())
72}
73
74async fn wait_for_language_servers_to_start(
75 project: &Entity<Project>,
76 buffer: &Entity<Buffer>,
77 step_progress: &Arc<StepProgress>,
78 cx: &mut AsyncApp,
79) -> anyhow::Result<()> {
80 let lsp_store = project.read_with(cx, |project, _| project.lsp_store());
81
82 let (language_server_ids, mut starting_language_server_ids) =
83 buffer.update(cx, |buffer, cx| {
84 lsp_store.update(cx, |lsp_store, cx| {
85 let ids = lsp_store.language_servers_for_local_buffer(buffer, cx);
86 let starting_ids = ids
87 .iter()
88 .copied()
89 .filter(|id| !lsp_store.language_server_statuses.contains_key(&id))
90 .collect::<HashSet<_>>();
91 (ids, starting_ids)
92 })
93 });
94
95 step_progress.set_substatus(format!("waiting for {} LSPs", language_server_ids.len()));
96
97 let timeout_duration = if starting_language_server_ids.is_empty() {
98 Duration::from_secs(30)
99 } else {
100 Duration::from_secs(60 * 5)
101 };
102
103 let timeout = cx.background_executor().timer(timeout_duration).shared();
104
105 let (mut tx, mut rx) = mpsc::channel(language_server_ids.len());
106 let added_subscription = cx.subscribe(project, {
107 let step_progress = step_progress.clone();
108 move |_, event, _| match event {
109 project::Event::LanguageServerAdded(language_server_id, name, _) => {
110 step_progress.set_substatus(format!("LSP started: {}", name));
111 tx.try_send(*language_server_id).ok();
112 }
113 _ => {}
114 }
115 });
116
117 while !starting_language_server_ids.is_empty() {
118 futures::select! {
119 language_server_id = rx.next() => {
120 if let Some(id) = language_server_id {
121 starting_language_server_ids.remove(&id);
122 }
123 },
124 _ = timeout.clone().fuse() => {
125 return Err(anyhow::anyhow!("LSP wait timed out after {} minutes", timeout_duration.as_secs() / 60));
126 }
127 }
128 }
129
130 drop(added_subscription);
131
132 let (mut tx, mut rx) = mpsc::channel(language_server_ids.len());
133 let subscriptions = [
134 cx.subscribe(&lsp_store, {
135 let step_progress = step_progress.clone();
136 move |_, event, _| {
137 if let project::LspStoreEvent::LanguageServerUpdate {
138 message:
139 client::proto::update_language_server::Variant::WorkProgress(
140 client::proto::LspWorkProgress {
141 message: Some(message),
142 ..
143 },
144 ),
145 ..
146 } = event
147 {
148 step_progress.set_substatus(message.clone());
149 }
150 }
151 }),
152 cx.subscribe(project, {
153 let step_progress = step_progress.clone();
154 let lsp_store = lsp_store.clone();
155 move |_, event, cx| match event {
156 project::Event::DiskBasedDiagnosticsFinished { language_server_id } => {
157 let lsp_store = lsp_store.read(cx);
158 let name = lsp_store
159 .language_server_adapter_for_id(*language_server_id)
160 .unwrap()
161 .name();
162 step_progress.set_substatus(format!("LSP idle: {}", name));
163 tx.try_send(*language_server_id).ok();
164 }
165 _ => {}
166 }
167 }),
168 ];
169
170 project
171 .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
172 .await?;
173
174 let mut pending_language_server_ids = lsp_store.read_with(cx, |lsp_store, _| {
175 language_server_ids
176 .iter()
177 .copied()
178 .filter(|id| {
179 lsp_store
180 .language_server_statuses
181 .get(id)
182 .is_some_and(|status| status.has_pending_diagnostic_updates)
183 })
184 .collect::<HashSet<_>>()
185 });
186 while !pending_language_server_ids.is_empty() {
187 futures::select! {
188 language_server_id = rx.next() => {
189 if let Some(id) = language_server_id {
190 pending_language_server_ids.remove(&id);
191 }
192 },
193 _ = timeout.clone().fuse() => {
194 return Err(anyhow::anyhow!("LSP wait timed out after {} minutes", timeout_duration.as_secs() / 60));
195 }
196 }
197 }
198
199 drop(subscriptions);
200 step_progress.clear_substatus();
201 Ok(())
202}