retrieve_context.rs

  1use crate::{
  2    example::{Example, ExampleContext},
  3    headless::EpAppState,
  4    load_project::run_load_project,
  5    progress::{InfoStyle, Progress, Step, StepProgress},
  6};
  7use anyhow::Context as _;
  8use collections::HashSet;
  9use edit_prediction::{DebugEvent, EditPredictionStore};
 10use futures::{FutureExt as _, StreamExt as _, channel::mpsc};
 11use gpui::{AsyncApp, Entity};
 12use language::Buffer;
 13use project::Project;
 14use std::sync::Arc;
 15use std::time::Duration;
 16
 17pub async fn run_context_retrieval(
 18    example: &mut Example,
 19    app_state: Arc<EpAppState>,
 20    mut cx: AsyncApp,
 21) -> anyhow::Result<()> {
 22    if example.context.is_some() {
 23        return Ok(());
 24    }
 25
 26    run_load_project(example, app_state.clone(), cx.clone()).await?;
 27
 28    let step_progress: Arc<StepProgress> = Progress::global()
 29        .start(Step::Context, &example.spec.name)
 30        .into();
 31
 32    let state = example.state.as_ref().unwrap();
 33    let project = state.project.clone();
 34
 35    let _lsp_handle = project.update(&mut cx, |project, cx| {
 36        project.register_buffer_with_language_servers(&state.buffer, cx)
 37    });
 38    wait_for_language_servers_to_start(&project, &state.buffer, &step_progress, &mut cx).await?;
 39
 40    let ep_store = cx
 41        .update(|cx| EditPredictionStore::try_global(cx))
 42        .context("EditPredictionStore not initialized")?;
 43
 44    let mut events = ep_store.update(&mut cx, |store, cx| {
 45        store.register_buffer(&state.buffer, &project, cx);
 46        store.set_use_context(true);
 47        store.refresh_context(&project, &state.buffer, state.cursor_position, cx);
 48        store.debug_info(&project, cx)
 49    });
 50
 51    while let Some(event) = events.next().await {
 52        match event {
 53            DebugEvent::ContextRetrievalFinished(_) => {
 54                break;
 55            }
 56            _ => {}
 57        }
 58    }
 59
 60    let context_files =
 61        ep_store.update(&mut cx, |store, cx| store.context_for_project(&project, cx));
 62
 63    let excerpt_count: usize = context_files.iter().map(|f| f.excerpts.len()).sum();
 64    step_progress.set_info(format!("{} excerpts", excerpt_count), InfoStyle::Normal);
 65
 66    example.context = Some(ExampleContext {
 67        files: context_files,
 68    });
 69    Ok(())
 70}
 71
 72async fn wait_for_language_servers_to_start(
 73    project: &Entity<Project>,
 74    buffer: &Entity<Buffer>,
 75    step_progress: &Arc<StepProgress>,
 76    cx: &mut AsyncApp,
 77) -> anyhow::Result<()> {
 78    let lsp_store = project.read_with(cx, |project, _| project.lsp_store());
 79
 80    let (language_server_ids, mut starting_language_server_ids) =
 81        buffer.update(cx, |buffer, cx| {
 82            lsp_store.update(cx, |lsp_store, cx| {
 83                let ids = lsp_store.language_servers_for_local_buffer(buffer, cx);
 84                let starting_ids = ids
 85                    .iter()
 86                    .copied()
 87                    .filter(|id| !lsp_store.language_server_statuses.contains_key(&id))
 88                    .collect::<HashSet<_>>();
 89                (ids, starting_ids)
 90            })
 91        });
 92
 93    step_progress.set_substatus(format!("waiting for {} LSPs", language_server_ids.len()));
 94
 95    let timeout = cx
 96        .background_executor()
 97        .timer(Duration::from_secs(60 * 5))
 98        .shared();
 99
100    let (mut tx, mut rx) = mpsc::channel(language_server_ids.len());
101    let added_subscription = cx.subscribe(project, {
102        let step_progress = step_progress.clone();
103        move |_, event, _| match event {
104            project::Event::LanguageServerAdded(language_server_id, name, _) => {
105                step_progress.set_substatus(format!("LSP started: {}", name));
106                tx.try_send(*language_server_id).ok();
107            }
108            _ => {}
109        }
110    });
111
112    while !starting_language_server_ids.is_empty() {
113        futures::select! {
114            language_server_id = rx.next() => {
115                if let Some(id) = language_server_id {
116                    starting_language_server_ids.remove(&id);
117                }
118            },
119            _ = timeout.clone().fuse() => {
120                return Err(anyhow::anyhow!("LSP wait timed out after 5 minutes"));
121            }
122        }
123    }
124
125    drop(added_subscription);
126
127    let (mut tx, mut rx) = mpsc::channel(language_server_ids.len());
128    let subscriptions = [
129        cx.subscribe(&lsp_store, {
130            let step_progress = step_progress.clone();
131            move |_, event, _| {
132                if let project::LspStoreEvent::LanguageServerUpdate {
133                    message:
134                        client::proto::update_language_server::Variant::WorkProgress(
135                            client::proto::LspWorkProgress {
136                                message: Some(message),
137                                ..
138                            },
139                        ),
140                    ..
141                } = event
142                {
143                    step_progress.set_substatus(message.clone());
144                }
145            }
146        }),
147        cx.subscribe(project, {
148            let step_progress = step_progress.clone();
149            let lsp_store = lsp_store.clone();
150            move |_, event, cx| match event {
151                project::Event::DiskBasedDiagnosticsFinished { language_server_id } => {
152                    let lsp_store = lsp_store.read(cx);
153                    let name = lsp_store
154                        .language_server_adapter_for_id(*language_server_id)
155                        .unwrap()
156                        .name();
157                    step_progress.set_substatus(format!("LSP idle: {}", name));
158                    tx.try_send(*language_server_id).ok();
159                }
160                _ => {}
161            }
162        }),
163    ];
164
165    project
166        .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
167        .await?;
168
169    let mut pending_language_server_ids = lsp_store.read_with(cx, |lsp_store, _| {
170        language_server_ids
171            .iter()
172            .copied()
173            .filter(|id| {
174                lsp_store
175                    .language_server_statuses
176                    .get(id)
177                    .is_some_and(|status| status.has_pending_diagnostic_updates)
178            })
179            .collect::<HashSet<_>>()
180    });
181    while !pending_language_server_ids.is_empty() {
182        futures::select! {
183            language_server_id = rx.next() => {
184                if let Some(id) = language_server_id {
185                    pending_language_server_ids.remove(&id);
186                }
187            },
188            _ = timeout.clone().fuse() => {
189                return Err(anyhow::anyhow!("LSP wait timed out after 5 minutes"));
190            }
191        }
192    }
193
194    drop(subscriptions);
195    step_progress.clear_substatus();
196    Ok(())
197}