retrieve_context.rs

  1use crate::{
  2    example::{Example, ExamplePromptInputs},
  3    headless::EpAppState,
  4    load_project::run_load_project,
  5    progress::{ExampleProgress, InfoStyle, Step, StepProgress},
  6};
  7use anyhow::Context as _;
  8use collections::HashSet;
  9use edit_prediction::{DebugEvent, EditPredictionStore};
 10use futures::{FutureExt as _, StreamExt as _, channel::mpsc};
 11use gpui::{AsyncApp, Entity};
 12use language::Buffer;
 13use project::Project;
 14use std::sync::Arc;
 15use std::time::Duration;
 16
 17pub async fn run_context_retrieval(
 18    example: &mut Example,
 19    app_state: Arc<EpAppState>,
 20    example_progress: &ExampleProgress,
 21    mut cx: AsyncApp,
 22) -> anyhow::Result<()> {
 23    if example
 24        .prompt_inputs
 25        .as_ref()
 26        .is_some_and(|inputs| inputs.related_files.is_some())
 27    {
 28        return Ok(());
 29    }
 30
 31    if let Some(captured) = &example.spec.captured_prompt_input {
 32        let step_progress = example_progress.start(Step::Context);
 33        step_progress.set_substatus("using captured prompt input");
 34
 35        let edit_history: Vec<Arc<zeta_prompt::Event>> = captured
 36            .events
 37            .iter()
 38            .map(|e| Arc::new(e.to_event()))
 39            .collect();
 40
 41        let related_files: Vec<zeta_prompt::RelatedFile> = captured
 42            .related_files
 43            .iter()
 44            .map(|rf| rf.to_related_file())
 45            .collect();
 46
 47        example.prompt_inputs = Some(ExamplePromptInputs {
 48            content: captured.cursor_file_content.clone(),
 49            cursor_row: captured.cursor_row,
 50            cursor_column: captured.cursor_column,
 51            cursor_offset: captured.cursor_offset,
 52            edit_history,
 53            related_files: Some(related_files),
 54        });
 55
 56        return Ok(());
 57    }
 58
 59    run_load_project(example, app_state.clone(), example_progress, cx.clone()).await?;
 60
 61    let step_progress: Arc<StepProgress> = example_progress.start(Step::Context).into();
 62
 63    let state = example.state.as_ref().unwrap();
 64    let project = state.project.clone();
 65
 66    let _lsp_handle = project.update(&mut cx, |project, cx| {
 67        project.register_buffer_with_language_servers(&state.buffer, cx)
 68    });
 69    wait_for_language_servers_to_start(&project, &state.buffer, &step_progress, &mut cx).await?;
 70
 71    let ep_store = cx
 72        .update(|cx| EditPredictionStore::try_global(cx))
 73        .context("EditPredictionStore not initialized")?;
 74
 75    let mut events = ep_store.update(&mut cx, |store, cx| {
 76        store.register_buffer(&state.buffer, &project, cx);
 77        store.set_use_context(true);
 78        store.refresh_context(&project, &state.buffer, state.cursor_position, cx);
 79        store.debug_info(&project, cx)
 80    });
 81
 82    while let Some(event) = events.next().await {
 83        match event {
 84            DebugEvent::ContextRetrievalFinished(_) => {
 85                break;
 86            }
 87            _ => {}
 88        }
 89    }
 90
 91    let context_files =
 92        ep_store.update(&mut cx, |store, cx| store.context_for_project(&project, cx));
 93
 94    let excerpt_count: usize = context_files.iter().map(|f| f.excerpts.len()).sum();
 95    step_progress.set_info(format!("{} excerpts", excerpt_count), InfoStyle::Normal);
 96
 97    if let Some(prompt_inputs) = example.prompt_inputs.as_mut() {
 98        prompt_inputs.related_files = Some(context_files);
 99    }
100    Ok(())
101}
102
103async fn wait_for_language_servers_to_start(
104    project: &Entity<Project>,
105    buffer: &Entity<Buffer>,
106    step_progress: &Arc<StepProgress>,
107    cx: &mut AsyncApp,
108) -> anyhow::Result<()> {
109    let lsp_store = project.read_with(cx, |project, _| project.lsp_store());
110
111    let (language_server_ids, mut starting_language_server_ids) =
112        buffer.update(cx, |buffer, cx| {
113            lsp_store.update(cx, |lsp_store, cx| {
114                let ids = lsp_store.language_servers_for_local_buffer(buffer, cx);
115                let starting_ids = ids
116                    .iter()
117                    .copied()
118                    .filter(|id| !lsp_store.language_server_statuses.contains_key(&id))
119                    .collect::<HashSet<_>>();
120                (ids, starting_ids)
121            })
122        });
123
124    step_progress.set_substatus(format!("waiting for {} LSPs", language_server_ids.len()));
125
126    let timeout_duration = if starting_language_server_ids.is_empty() {
127        Duration::from_secs(30)
128    } else {
129        Duration::from_secs(60 * 5)
130    };
131
132    let timeout = cx.background_executor().timer(timeout_duration).shared();
133
134    let (mut tx, mut rx) = mpsc::channel(language_server_ids.len());
135    let added_subscription = cx.subscribe(project, {
136        let step_progress = step_progress.clone();
137        move |_, event, _| match event {
138            project::Event::LanguageServerAdded(language_server_id, name, _) => {
139                step_progress.set_substatus(format!("LSP started: {}", name));
140                tx.try_send(*language_server_id).ok();
141            }
142            _ => {}
143        }
144    });
145
146    while !starting_language_server_ids.is_empty() {
147        futures::select! {
148            language_server_id = rx.next() => {
149                if let Some(id) = language_server_id {
150                    starting_language_server_ids.remove(&id);
151                }
152            },
153            _ = timeout.clone().fuse() => {
154                return Err(anyhow::anyhow!("LSP wait timed out after {} minutes", timeout_duration.as_secs() / 60));
155            }
156        }
157    }
158
159    drop(added_subscription);
160
161    let (mut tx, mut rx) = mpsc::channel(language_server_ids.len());
162    let subscriptions = [
163        cx.subscribe(&lsp_store, {
164            let step_progress = step_progress.clone();
165            move |_, event, _| {
166                if let project::LspStoreEvent::LanguageServerUpdate {
167                    message:
168                        client::proto::update_language_server::Variant::WorkProgress(
169                            client::proto::LspWorkProgress {
170                                message: Some(message),
171                                ..
172                            },
173                        ),
174                    ..
175                } = event
176                {
177                    step_progress.set_substatus(message.clone());
178                }
179            }
180        }),
181        cx.subscribe(project, {
182            let step_progress = step_progress.clone();
183            let lsp_store = lsp_store.clone();
184            move |_, event, cx| match event {
185                project::Event::DiskBasedDiagnosticsFinished { language_server_id } => {
186                    let lsp_store = lsp_store.read(cx);
187                    let name = lsp_store
188                        .language_server_adapter_for_id(*language_server_id)
189                        .unwrap()
190                        .name();
191                    step_progress.set_substatus(format!("LSP idle: {}", name));
192                    tx.try_send(*language_server_id).ok();
193                }
194                _ => {}
195            }
196        }),
197    ];
198
199    project
200        .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
201        .await?;
202
203    let mut pending_language_server_ids = lsp_store.read_with(cx, |lsp_store, _| {
204        language_server_ids
205            .iter()
206            .copied()
207            .filter(|id| {
208                lsp_store
209                    .language_server_statuses
210                    .get(id)
211                    .is_some_and(|status| status.has_pending_diagnostic_updates)
212            })
213            .collect::<HashSet<_>>()
214    });
215    while !pending_language_server_ids.is_empty() {
216        futures::select! {
217            language_server_id = rx.next() => {
218                if let Some(id) = language_server_id {
219                    pending_language_server_ids.remove(&id);
220                }
221            },
222            _ = timeout.clone().fuse() => {
223                return Err(anyhow::anyhow!("LSP wait timed out after {} minutes", timeout_duration.as_secs() / 60));
224            }
225        }
226    }
227
228    drop(subscriptions);
229    step_progress.clear_substatus();
230    Ok(())
231}