1use crate::{
2 example::{Example, ExamplePromptInputs},
3 headless::EpAppState,
4 load_project::run_load_project,
5 progress::{ExampleProgress, InfoStyle, Step, StepProgress},
6};
7use anyhow::Context as _;
8use collections::HashSet;
9use edit_prediction::{DebugEvent, EditPredictionStore};
10use futures::{FutureExt as _, StreamExt as _, channel::mpsc};
11use gpui::{AsyncApp, Entity};
12use language::Buffer;
13use project::Project;
14use std::sync::Arc;
15use std::time::Duration;
16
17pub async fn run_context_retrieval(
18 example: &mut Example,
19 app_state: Arc<EpAppState>,
20 example_progress: &ExampleProgress,
21 mut cx: AsyncApp,
22) -> anyhow::Result<()> {
23 if example
24 .prompt_inputs
25 .as_ref()
26 .is_some_and(|inputs| inputs.related_files.is_some())
27 {
28 return Ok(());
29 }
30
31 if let Some(captured) = &example.spec.captured_prompt_input {
32 let step_progress = example_progress.start(Step::Context);
33 step_progress.set_substatus("using captured prompt input");
34
35 let edit_history: Vec<Arc<zeta_prompt::Event>> = captured
36 .events
37 .iter()
38 .map(|e| Arc::new(e.to_event()))
39 .collect();
40
41 let related_files: Vec<zeta_prompt::RelatedFile> = captured
42 .related_files
43 .iter()
44 .map(|rf| rf.to_related_file())
45 .collect();
46
47 example.prompt_inputs = Some(ExamplePromptInputs {
48 content: captured.cursor_file_content.clone(),
49 cursor_row: captured.cursor_row,
50 cursor_column: captured.cursor_column,
51 cursor_offset: captured.cursor_offset,
52 edit_history,
53 related_files: Some(related_files),
54 });
55
56 return Ok(());
57 }
58
59 run_load_project(example, app_state.clone(), example_progress, cx.clone()).await?;
60
61 let step_progress: Arc<StepProgress> = example_progress.start(Step::Context).into();
62
63 let state = example.state.as_ref().unwrap();
64 let project = state.project.clone();
65
66 let _lsp_handle = project.update(&mut cx, |project, cx| {
67 project.register_buffer_with_language_servers(&state.buffer, cx)
68 });
69 wait_for_language_servers_to_start(&project, &state.buffer, &step_progress, &mut cx).await?;
70
71 let ep_store = cx
72 .update(|cx| EditPredictionStore::try_global(cx))
73 .context("EditPredictionStore not initialized")?;
74
75 let mut events = ep_store.update(&mut cx, |store, cx| {
76 store.register_buffer(&state.buffer, &project, cx);
77 store.set_use_context(true);
78 store.refresh_context(&project, &state.buffer, state.cursor_position, cx);
79 store.debug_info(&project, cx)
80 });
81
82 while let Some(event) = events.next().await {
83 match event {
84 DebugEvent::ContextRetrievalFinished(_) => {
85 break;
86 }
87 _ => {}
88 }
89 }
90
91 let context_files =
92 ep_store.update(&mut cx, |store, cx| store.context_for_project(&project, cx));
93
94 let excerpt_count: usize = context_files.iter().map(|f| f.excerpts.len()).sum();
95 step_progress.set_info(format!("{} excerpts", excerpt_count), InfoStyle::Normal);
96
97 if let Some(prompt_inputs) = example.prompt_inputs.as_mut() {
98 prompt_inputs.related_files = Some(context_files);
99 }
100 Ok(())
101}
102
103async fn wait_for_language_servers_to_start(
104 project: &Entity<Project>,
105 buffer: &Entity<Buffer>,
106 step_progress: &Arc<StepProgress>,
107 cx: &mut AsyncApp,
108) -> anyhow::Result<()> {
109 let lsp_store = project.read_with(cx, |project, _| project.lsp_store());
110
111 let (language_server_ids, mut starting_language_server_ids) =
112 buffer.update(cx, |buffer, cx| {
113 lsp_store.update(cx, |lsp_store, cx| {
114 let ids = lsp_store.language_servers_for_local_buffer(buffer, cx);
115 let starting_ids = ids
116 .iter()
117 .copied()
118 .filter(|id| !lsp_store.language_server_statuses.contains_key(&id))
119 .collect::<HashSet<_>>();
120 (ids, starting_ids)
121 })
122 });
123
124 step_progress.set_substatus(format!("waiting for {} LSPs", language_server_ids.len()));
125
126 let timeout_duration = if starting_language_server_ids.is_empty() {
127 Duration::from_secs(30)
128 } else {
129 Duration::from_secs(60 * 5)
130 };
131
132 let timeout = cx.background_executor().timer(timeout_duration).shared();
133
134 let (mut tx, mut rx) = mpsc::channel(language_server_ids.len());
135 let added_subscription = cx.subscribe(project, {
136 let step_progress = step_progress.clone();
137 move |_, event, _| match event {
138 project::Event::LanguageServerAdded(language_server_id, name, _) => {
139 step_progress.set_substatus(format!("LSP started: {}", name));
140 tx.try_send(*language_server_id).ok();
141 }
142 _ => {}
143 }
144 });
145
146 while !starting_language_server_ids.is_empty() {
147 futures::select! {
148 language_server_id = rx.next() => {
149 if let Some(id) = language_server_id {
150 starting_language_server_ids.remove(&id);
151 }
152 },
153 _ = timeout.clone().fuse() => {
154 return Err(anyhow::anyhow!("LSP wait timed out after {} minutes", timeout_duration.as_secs() / 60));
155 }
156 }
157 }
158
159 drop(added_subscription);
160
161 let (mut tx, mut rx) = mpsc::channel(language_server_ids.len());
162 let subscriptions = [
163 cx.subscribe(&lsp_store, {
164 let step_progress = step_progress.clone();
165 move |_, event, _| {
166 if let project::LspStoreEvent::LanguageServerUpdate {
167 message:
168 client::proto::update_language_server::Variant::WorkProgress(
169 client::proto::LspWorkProgress {
170 message: Some(message),
171 ..
172 },
173 ),
174 ..
175 } = event
176 {
177 step_progress.set_substatus(message.clone());
178 }
179 }
180 }),
181 cx.subscribe(project, {
182 let step_progress = step_progress.clone();
183 let lsp_store = lsp_store.clone();
184 move |_, event, cx| match event {
185 project::Event::DiskBasedDiagnosticsFinished { language_server_id } => {
186 let lsp_store = lsp_store.read(cx);
187 let name = lsp_store
188 .language_server_adapter_for_id(*language_server_id)
189 .unwrap()
190 .name();
191 step_progress.set_substatus(format!("LSP idle: {}", name));
192 tx.try_send(*language_server_id).ok();
193 }
194 _ => {}
195 }
196 }),
197 ];
198
199 project
200 .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
201 .await?;
202
203 let mut pending_language_server_ids = lsp_store.read_with(cx, |lsp_store, _| {
204 language_server_ids
205 .iter()
206 .copied()
207 .filter(|id| {
208 lsp_store
209 .language_server_statuses
210 .get(id)
211 .is_some_and(|status| status.has_pending_diagnostic_updates)
212 })
213 .collect::<HashSet<_>>()
214 });
215 while !pending_language_server_ids.is_empty() {
216 futures::select! {
217 language_server_id = rx.next() => {
218 if let Some(id) = language_server_id {
219 pending_language_server_ids.remove(&id);
220 }
221 },
222 _ = timeout.clone().fuse() => {
223 return Err(anyhow::anyhow!("LSP wait timed out after {} minutes", timeout_duration.as_secs() / 60));
224 }
225 }
226 }
227
228 drop(subscriptions);
229 step_progress.clear_substatus();
230 Ok(())
231}