retrieve_context.rs

  1use crate::{
  2    example::Example,
  3    headless::EpAppState,
  4    load_project::run_load_project,
  5    progress::{ExampleProgress, InfoStyle, Step, StepProgress},
  6};
  7use anyhow::Context as _;
  8use collections::HashSet;
  9use edit_prediction::{DebugEvent, EditPredictionStore};
 10use futures::{FutureExt as _, StreamExt as _, channel::mpsc};
 11use gpui::{AsyncApp, Entity};
 12use language::Buffer;
 13use project::Project;
 14use std::sync::Arc;
 15use std::time::Duration;
 16
 17pub async fn run_context_retrieval(
 18    example: &mut Example,
 19    app_state: Arc<EpAppState>,
 20    example_progress: &ExampleProgress,
 21    mut cx: AsyncApp,
 22) -> anyhow::Result<()> {
 23    if example
 24        .prompt_inputs
 25        .as_ref()
 26        .is_some_and(|inputs| inputs.related_files.is_some())
 27        || example.spec.repository_url.is_empty()
 28    {
 29        return Ok(());
 30    }
 31
 32    run_load_project(example, app_state.clone(), example_progress, cx.clone()).await?;
 33
 34    let step_progress: Arc<StepProgress> = example_progress.start(Step::Context).into();
 35
 36    let state = example.state.as_ref().unwrap();
 37    let project = state.project.clone();
 38
 39    let _lsp_handle = project.update(&mut cx, |project, cx| {
 40        project.register_buffer_with_language_servers(&state.buffer, cx)
 41    });
 42    wait_for_language_servers_to_start(&project, &state.buffer, &step_progress, &mut cx).await?;
 43
 44    let ep_store = cx
 45        .update(|cx| EditPredictionStore::try_global(cx))
 46        .context("EditPredictionStore not initialized")?;
 47
 48    let mut events = ep_store.update(&mut cx, |store, cx| {
 49        store.register_buffer(&state.buffer, &project, cx);
 50        store.refresh_context(&project, &state.buffer, state.cursor_position, cx);
 51        store.debug_info(&project, cx)
 52    });
 53
 54    while let Some(event) = events.next().await {
 55        match event {
 56            DebugEvent::ContextRetrievalFinished(_) => {
 57                break;
 58            }
 59            _ => {}
 60        }
 61    }
 62
 63    let context_files =
 64        ep_store.update(&mut cx, |store, cx| store.context_for_project(&project, cx));
 65
 66    let excerpt_count: usize = context_files.iter().map(|f| f.excerpts.len()).sum();
 67    step_progress.set_info(format!("{} excerpts", excerpt_count), InfoStyle::Normal);
 68
 69    if let Some(prompt_inputs) = example.prompt_inputs.as_mut() {
 70        prompt_inputs.related_files = Some(context_files);
 71    }
 72    Ok(())
 73}
 74
 75async fn wait_for_language_servers_to_start(
 76    project: &Entity<Project>,
 77    buffer: &Entity<Buffer>,
 78    step_progress: &Arc<StepProgress>,
 79    cx: &mut AsyncApp,
 80) -> anyhow::Result<()> {
 81    let lsp_store = project.read_with(cx, |project, _| project.lsp_store());
 82
 83    // Determine which servers exist for this buffer, and which are still starting.
 84    let mut servers_pending_start = HashSet::default();
 85    let mut servers_pending_diagnostics = HashSet::default();
 86    buffer.update(cx, |buffer, cx| {
 87        lsp_store.update(cx, |lsp_store, cx| {
 88            let ids = lsp_store.language_servers_for_local_buffer(buffer, cx);
 89            for &id in &ids {
 90                match lsp_store.language_server_statuses.get(&id) {
 91                    None => {
 92                        servers_pending_start.insert(id);
 93                        servers_pending_diagnostics.insert(id);
 94                    }
 95                    Some(status) if status.has_pending_diagnostic_updates => {
 96                        servers_pending_diagnostics.insert(id);
 97                    }
 98                    Some(_) => {}
 99                }
100            }
101        });
102    });
103
104    step_progress.set_substatus(format!(
105        "waiting for {} LSPs",
106        servers_pending_diagnostics.len()
107    ));
108
109    let timeout_duration = if servers_pending_start.is_empty() {
110        Duration::from_secs(30)
111    } else {
112        Duration::from_secs(60 * 5)
113    };
114    let timeout = cx.background_executor().timer(timeout_duration).shared();
115
116    let (mut started_tx, mut started_rx) = mpsc::channel(servers_pending_start.len().max(1));
117    let (mut diag_tx, mut diag_rx) = mpsc::channel(servers_pending_diagnostics.len().max(1));
118    let subscriptions = [cx.subscribe(&lsp_store, {
119        let step_progress = step_progress.clone();
120        move |lsp_store, event, cx| match event {
121            project::LspStoreEvent::LanguageServerAdded(id, name, _) => {
122                step_progress.set_substatus(format!("LSP started: {}", name));
123                started_tx.try_send(*id).ok();
124            }
125            project::LspStoreEvent::DiskBasedDiagnosticsFinished { language_server_id } => {
126                let name = lsp_store
127                    .read(cx)
128                    .language_server_adapter_for_id(*language_server_id)
129                    .unwrap()
130                    .name();
131                step_progress.set_substatus(format!("LSP idle: {}", name));
132                diag_tx.try_send(*language_server_id).ok();
133            }
134            project::LspStoreEvent::LanguageServerUpdate {
135                message:
136                    client::proto::update_language_server::Variant::WorkProgress(
137                        client::proto::LspWorkProgress {
138                            message: Some(message),
139                            ..
140                        },
141                    ),
142                ..
143            } => {
144                step_progress.set_substatus(message.clone());
145            }
146            _ => {}
147        }
148    })];
149
150    // Phase 1: wait for all servers to start.
151    while !servers_pending_start.is_empty() {
152        futures::select! {
153            id = started_rx.next() => {
154                if let Some(id) = id {
155                    servers_pending_start.remove(&id);
156                }
157            },
158            _ = timeout.clone().fuse() => {
159                return Err(anyhow::anyhow!("LSP wait timed out after {} minutes", timeout_duration.as_secs() / 60));
160            }
161        }
162    }
163
164    // Save the buffer so the server sees the current content and kicks off diagnostics.
165    project
166        .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
167        .await?;
168
169    // Phase 2: wait for all servers to finish their diagnostic pass.
170    while !servers_pending_diagnostics.is_empty() {
171        futures::select! {
172            id = diag_rx.next() => {
173                if let Some(id) = id {
174                    servers_pending_diagnostics.remove(&id);
175                }
176            },
177            _ = timeout.clone().fuse() => {
178                return Err(anyhow::anyhow!("LSP wait timed out after {} minutes", timeout_duration.as_secs() / 60));
179            }
180        }
181    }
182
183    drop(subscriptions);
184    step_progress.clear_substatus();
185    Ok(())
186}