1use crate::{
2 example::Example,
3 headless::EpAppState,
4 load_project::run_load_project,
5 progress::{ExampleProgress, InfoStyle, Step, StepProgress},
6};
7use anyhow::Context as _;
8use collections::HashSet;
9use edit_prediction::{DebugEvent, EditPredictionStore};
10use futures::{FutureExt as _, StreamExt as _, channel::mpsc};
11use gpui::{AsyncApp, Entity};
12use language::Buffer;
13use project::Project;
14use std::sync::Arc;
15use std::time::Duration;
16
17pub async fn run_context_retrieval(
18 example: &mut Example,
19 app_state: Arc<EpAppState>,
20 example_progress: &ExampleProgress,
21 mut cx: AsyncApp,
22) -> anyhow::Result<()> {
23 if example.prompt_inputs.is_some() {
24 if example.spec.repository_url.is_empty() {
25 return Ok(());
26 }
27
28 if example
29 .prompt_inputs
30 .as_ref()
31 .is_some_and(|inputs| !inputs.related_files.is_empty())
32 {
33 return Ok(());
34 }
35 }
36
37 run_load_project(example, app_state.clone(), example_progress, cx.clone()).await?;
38
39 let step_progress: Arc<StepProgress> = example_progress.start(Step::Context).into();
40
41 let state = example.state.as_ref().unwrap();
42 let project = state.project.clone();
43
44 let _lsp_handle = project.update(&mut cx, |project, cx| {
45 project.register_buffer_with_language_servers(&state.buffer, cx)
46 });
47 wait_for_language_servers_to_start(&project, &state.buffer, &step_progress, &mut cx).await?;
48
49 let ep_store = cx
50 .update(|cx| EditPredictionStore::try_global(cx))
51 .context("EditPredictionStore not initialized")?;
52
53 let mut events = ep_store.update(&mut cx, |store, cx| {
54 store.register_buffer(&state.buffer, &project, cx);
55 store.refresh_context(&project, &state.buffer, state.cursor_position, cx);
56 store.debug_info(&project, cx)
57 });
58
59 while let Some(event) = events.next().await {
60 match event {
61 DebugEvent::ContextRetrievalFinished(_) => {
62 break;
63 }
64 _ => {}
65 }
66 }
67
68 let context_files =
69 ep_store.update(&mut cx, |store, cx| store.context_for_project(&project, cx));
70
71 let excerpt_count: usize = context_files.iter().map(|f| f.excerpts.len()).sum();
72 step_progress.set_info(format!("{} excerpts", excerpt_count), InfoStyle::Normal);
73
74 if let Some(prompt_inputs) = example.prompt_inputs.as_mut() {
75 prompt_inputs.related_files = context_files;
76 }
77 Ok(())
78}
79
80async fn wait_for_language_servers_to_start(
81 project: &Entity<Project>,
82 buffer: &Entity<Buffer>,
83 step_progress: &Arc<StepProgress>,
84 cx: &mut AsyncApp,
85) -> anyhow::Result<()> {
86 let lsp_store = project.read_with(cx, |project, _| project.lsp_store());
87
88 let (language_server_ids, mut starting_language_server_ids) =
89 buffer.update(cx, |buffer, cx| {
90 lsp_store.update(cx, |lsp_store, cx| {
91 let ids = lsp_store.language_servers_for_local_buffer(buffer, cx);
92 let starting_ids = ids
93 .iter()
94 .copied()
95 .filter(|id| !lsp_store.language_server_statuses.contains_key(&id))
96 .collect::<HashSet<_>>();
97 (ids, starting_ids)
98 })
99 });
100
101 step_progress.set_substatus(format!("waiting for {} LSPs", language_server_ids.len()));
102
103 let timeout_duration = if starting_language_server_ids.is_empty() {
104 Duration::from_secs(30)
105 } else {
106 Duration::from_secs(60 * 5)
107 };
108
109 let timeout = cx.background_executor().timer(timeout_duration).shared();
110
111 let (mut tx, mut rx) = mpsc::channel(language_server_ids.len());
112 let added_subscription = cx.subscribe(project, {
113 let step_progress = step_progress.clone();
114 move |_, event, _| match event {
115 project::Event::LanguageServerAdded(language_server_id, name, _) => {
116 step_progress.set_substatus(format!("LSP started: {}", name));
117 tx.try_send(*language_server_id).ok();
118 }
119 _ => {}
120 }
121 });
122
123 while !starting_language_server_ids.is_empty() {
124 futures::select! {
125 language_server_id = rx.next() => {
126 if let Some(id) = language_server_id {
127 starting_language_server_ids.remove(&id);
128 }
129 },
130 _ = timeout.clone().fuse() => {
131 return Err(anyhow::anyhow!("LSP wait timed out after {} minutes", timeout_duration.as_secs() / 60));
132 }
133 }
134 }
135
136 drop(added_subscription);
137
138 let (mut tx, mut rx) = mpsc::channel(language_server_ids.len());
139 let subscriptions = [
140 cx.subscribe(&lsp_store, {
141 let step_progress = step_progress.clone();
142 move |_, event, _| {
143 if let project::LspStoreEvent::LanguageServerUpdate {
144 message:
145 client::proto::update_language_server::Variant::WorkProgress(
146 client::proto::LspWorkProgress {
147 message: Some(message),
148 ..
149 },
150 ),
151 ..
152 } = event
153 {
154 step_progress.set_substatus(message.clone());
155 }
156 }
157 }),
158 cx.subscribe(project, {
159 let step_progress = step_progress.clone();
160 let lsp_store = lsp_store.clone();
161 move |_, event, cx| match event {
162 project::Event::DiskBasedDiagnosticsFinished { language_server_id } => {
163 let lsp_store = lsp_store.read(cx);
164 let name = lsp_store
165 .language_server_adapter_for_id(*language_server_id)
166 .unwrap()
167 .name();
168 step_progress.set_substatus(format!("LSP idle: {}", name));
169 tx.try_send(*language_server_id).ok();
170 }
171 _ => {}
172 }
173 }),
174 ];
175
176 project
177 .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
178 .await?;
179
180 let mut pending_language_server_ids = lsp_store.read_with(cx, |lsp_store, _| {
181 language_server_ids
182 .iter()
183 .copied()
184 .filter(|id| {
185 lsp_store
186 .language_server_statuses
187 .get(id)
188 .is_some_and(|status| status.has_pending_diagnostic_updates)
189 })
190 .collect::<HashSet<_>>()
191 });
192 while !pending_language_server_ids.is_empty() {
193 futures::select! {
194 language_server_id = rx.next() => {
195 if let Some(id) = language_server_id {
196 pending_language_server_ids.remove(&id);
197 }
198 },
199 _ = timeout.clone().fuse() => {
200 return Err(anyhow::anyhow!("LSP wait timed out after {} minutes", timeout_duration.as_secs() / 60));
201 }
202 }
203 }
204
205 drop(subscriptions);
206 step_progress.clear_substatus();
207 Ok(())
208}