1use crate::{
2 example::{Example, ExamplePromptInputs},
3 headless::EpAppState,
4 load_project::run_load_project,
5 progress::{ExampleProgress, InfoStyle, Step, StepProgress},
6};
7use anyhow::Context as _;
8use collections::HashSet;
9use edit_prediction::{DebugEvent, EditPredictionStore};
10use futures::{FutureExt as _, StreamExt as _, channel::mpsc};
11use gpui::{AsyncApp, Entity};
12use language::Buffer;
13use project::Project;
14use std::sync::Arc;
15use std::time::Duration;
16
17pub async fn run_context_retrieval(
18 example: &mut Example,
19 app_state: Arc<EpAppState>,
20 example_progress: &ExampleProgress,
21 mut cx: AsyncApp,
22) -> anyhow::Result<()> {
23 if example
24 .prompt_inputs
25 .as_ref()
26 .is_some_and(|inputs| inputs.related_files.is_some())
27 {
28 return Ok(());
29 }
30
31 if let Some(captured) = &example.spec.captured_prompt_input {
32 let step_progress = example_progress.start(Step::Context);
33 step_progress.set_substatus("using captured prompt input");
34
35 let edit_history: Vec<Arc<zeta_prompt::Event>> = captured
36 .events
37 .iter()
38 .map(|e| Arc::new(e.to_event()))
39 .collect();
40
41 let related_files: Vec<zeta_prompt::RelatedFile> = captured
42 .related_files
43 .iter()
44 .map(|rf| rf.to_related_file())
45 .collect();
46
47 example.prompt_inputs = Some(ExamplePromptInputs {
48 content: captured.cursor_file_content.clone(),
49 cursor_row: captured.cursor_row,
50 cursor_column: captured.cursor_column,
51 cursor_offset: captured.cursor_offset,
52 edit_history,
53 related_files: Some(related_files),
54 });
55
56 return Ok(());
57 }
58
59 run_load_project(example, app_state.clone(), example_progress, cx.clone()).await?;
60
61 let step_progress: Arc<StepProgress> = example_progress.start(Step::Context).into();
62
63 let state = example.state.as_ref().unwrap();
64 let project = state.project.clone();
65
66 let _lsp_handle = project.update(&mut cx, |project, cx| {
67 project.register_buffer_with_language_servers(&state.buffer, cx)
68 });
69 wait_for_language_servers_to_start(&project, &state.buffer, &step_progress, &mut cx).await?;
70
71 let ep_store = cx
72 .update(|cx| EditPredictionStore::try_global(cx))
73 .context("EditPredictionStore not initialized")?;
74
75 let mut events = ep_store.update(&mut cx, |store, cx| {
76 store.register_buffer(&state.buffer, &project, cx);
77 store.refresh_context(&project, &state.buffer, state.cursor_position, cx);
78 store.debug_info(&project, cx)
79 });
80
81 while let Some(event) = events.next().await {
82 match event {
83 DebugEvent::ContextRetrievalFinished(_) => {
84 break;
85 }
86 _ => {}
87 }
88 }
89
90 let context_files =
91 ep_store.update(&mut cx, |store, cx| store.context_for_project(&project, cx));
92
93 let excerpt_count: usize = context_files.iter().map(|f| f.excerpts.len()).sum();
94 step_progress.set_info(format!("{} excerpts", excerpt_count), InfoStyle::Normal);
95
96 if let Some(prompt_inputs) = example.prompt_inputs.as_mut() {
97 prompt_inputs.related_files = Some(context_files);
98 }
99 Ok(())
100}
101
102async fn wait_for_language_servers_to_start(
103 project: &Entity<Project>,
104 buffer: &Entity<Buffer>,
105 step_progress: &Arc<StepProgress>,
106 cx: &mut AsyncApp,
107) -> anyhow::Result<()> {
108 let lsp_store = project.read_with(cx, |project, _| project.lsp_store());
109
110 let (language_server_ids, mut starting_language_server_ids) =
111 buffer.update(cx, |buffer, cx| {
112 lsp_store.update(cx, |lsp_store, cx| {
113 let ids = lsp_store.language_servers_for_local_buffer(buffer, cx);
114 let starting_ids = ids
115 .iter()
116 .copied()
117 .filter(|id| !lsp_store.language_server_statuses.contains_key(&id))
118 .collect::<HashSet<_>>();
119 (ids, starting_ids)
120 })
121 });
122
123 step_progress.set_substatus(format!("waiting for {} LSPs", language_server_ids.len()));
124
125 let timeout_duration = if starting_language_server_ids.is_empty() {
126 Duration::from_secs(30)
127 } else {
128 Duration::from_secs(60 * 5)
129 };
130
131 let timeout = cx.background_executor().timer(timeout_duration).shared();
132
133 let (mut tx, mut rx) = mpsc::channel(language_server_ids.len());
134 let added_subscription = cx.subscribe(project, {
135 let step_progress = step_progress.clone();
136 move |_, event, _| match event {
137 project::Event::LanguageServerAdded(language_server_id, name, _) => {
138 step_progress.set_substatus(format!("LSP started: {}", name));
139 tx.try_send(*language_server_id).ok();
140 }
141 _ => {}
142 }
143 });
144
145 while !starting_language_server_ids.is_empty() {
146 futures::select! {
147 language_server_id = rx.next() => {
148 if let Some(id) = language_server_id {
149 starting_language_server_ids.remove(&id);
150 }
151 },
152 _ = timeout.clone().fuse() => {
153 return Err(anyhow::anyhow!("LSP wait timed out after {} minutes", timeout_duration.as_secs() / 60));
154 }
155 }
156 }
157
158 drop(added_subscription);
159
160 let (mut tx, mut rx) = mpsc::channel(language_server_ids.len());
161 let subscriptions = [
162 cx.subscribe(&lsp_store, {
163 let step_progress = step_progress.clone();
164 move |_, event, _| {
165 if let project::LspStoreEvent::LanguageServerUpdate {
166 message:
167 client::proto::update_language_server::Variant::WorkProgress(
168 client::proto::LspWorkProgress {
169 message: Some(message),
170 ..
171 },
172 ),
173 ..
174 } = event
175 {
176 step_progress.set_substatus(message.clone());
177 }
178 }
179 }),
180 cx.subscribe(project, {
181 let step_progress = step_progress.clone();
182 let lsp_store = lsp_store.clone();
183 move |_, event, cx| match event {
184 project::Event::DiskBasedDiagnosticsFinished { language_server_id } => {
185 let lsp_store = lsp_store.read(cx);
186 let name = lsp_store
187 .language_server_adapter_for_id(*language_server_id)
188 .unwrap()
189 .name();
190 step_progress.set_substatus(format!("LSP idle: {}", name));
191 tx.try_send(*language_server_id).ok();
192 }
193 _ => {}
194 }
195 }),
196 ];
197
198 project
199 .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
200 .await?;
201
202 let mut pending_language_server_ids = lsp_store.read_with(cx, |lsp_store, _| {
203 language_server_ids
204 .iter()
205 .copied()
206 .filter(|id| {
207 lsp_store
208 .language_server_statuses
209 .get(id)
210 .is_some_and(|status| status.has_pending_diagnostic_updates)
211 })
212 .collect::<HashSet<_>>()
213 });
214 while !pending_language_server_ids.is_empty() {
215 futures::select! {
216 language_server_id = rx.next() => {
217 if let Some(id) = language_server_id {
218 pending_language_server_ids.remove(&id);
219 }
220 },
221 _ = timeout.clone().fuse() => {
222 return Err(anyhow::anyhow!("LSP wait timed out after {} minutes", timeout_duration.as_secs() / 60));
223 }
224 }
225 }
226
227 drop(subscriptions);
228 step_progress.clear_substatus();
229 Ok(())
230}