retrieve_context.rs

  1use crate::{
  2    example::Example,
  3    headless::EpAppState,
  4    load_project::run_load_project,
  5    progress::{InfoStyle, Progress, Step, StepProgress},
  6};
  7use anyhow::Context as _;
  8use collections::HashSet;
  9use edit_prediction::{DebugEvent, EditPredictionStore};
 10use futures::{FutureExt as _, StreamExt as _, channel::mpsc};
 11use gpui::{AsyncApp, Entity};
 12use language::Buffer;
 13use project::Project;
 14use std::sync::Arc;
 15use std::time::Duration;
 16
 17pub async fn run_context_retrieval(
 18    example: &mut Example,
 19    app_state: Arc<EpAppState>,
 20    mut cx: AsyncApp,
 21) -> anyhow::Result<()> {
 22    if example
 23        .prompt_inputs
 24        .as_ref()
 25        .is_some_and(|inputs| inputs.related_files.is_some())
 26    {
 27        return Ok(());
 28    }
 29
 30    run_load_project(example, app_state.clone(), cx.clone()).await?;
 31
 32    let step_progress: Arc<StepProgress> = Progress::global()
 33        .start(Step::Context, &example.spec.name)
 34        .into();
 35
 36    let state = example.state.as_ref().unwrap();
 37    let project = state.project.clone();
 38
 39    let _lsp_handle = project.update(&mut cx, |project, cx| {
 40        project.register_buffer_with_language_servers(&state.buffer, cx)
 41    });
 42    wait_for_language_servers_to_start(&project, &state.buffer, &step_progress, &mut cx).await?;
 43
 44    let ep_store = cx
 45        .update(|cx| EditPredictionStore::try_global(cx))
 46        .context("EditPredictionStore not initialized")?;
 47
 48    let mut events = ep_store.update(&mut cx, |store, cx| {
 49        store.register_buffer(&state.buffer, &project, cx);
 50        store.set_use_context(true);
 51        store.refresh_context(&project, &state.buffer, state.cursor_position, cx);
 52        store.debug_info(&project, cx)
 53    });
 54
 55    while let Some(event) = events.next().await {
 56        match event {
 57            DebugEvent::ContextRetrievalFinished(_) => {
 58                break;
 59            }
 60            _ => {}
 61        }
 62    }
 63
 64    let context_files =
 65        ep_store.update(&mut cx, |store, cx| store.context_for_project(&project, cx));
 66
 67    let excerpt_count: usize = context_files.iter().map(|f| f.excerpts.len()).sum();
 68    step_progress.set_info(format!("{} excerpts", excerpt_count), InfoStyle::Normal);
 69
 70    if let Some(prompt_inputs) = example.prompt_inputs.as_mut() {
 71        prompt_inputs.related_files = Some(context_files);
 72    }
 73    Ok(())
 74}
 75
 76async fn wait_for_language_servers_to_start(
 77    project: &Entity<Project>,
 78    buffer: &Entity<Buffer>,
 79    step_progress: &Arc<StepProgress>,
 80    cx: &mut AsyncApp,
 81) -> anyhow::Result<()> {
 82    let lsp_store = project.read_with(cx, |project, _| project.lsp_store());
 83
 84    let (language_server_ids, mut starting_language_server_ids) =
 85        buffer.update(cx, |buffer, cx| {
 86            lsp_store.update(cx, |lsp_store, cx| {
 87                let ids = lsp_store.language_servers_for_local_buffer(buffer, cx);
 88                let starting_ids = ids
 89                    .iter()
 90                    .copied()
 91                    .filter(|id| !lsp_store.language_server_statuses.contains_key(&id))
 92                    .collect::<HashSet<_>>();
 93                (ids, starting_ids)
 94            })
 95        });
 96
 97    step_progress.set_substatus(format!("waiting for {} LSPs", language_server_ids.len()));
 98
 99    let timeout = cx
100        .background_executor()
101        .timer(Duration::from_secs(60 * 5))
102        .shared();
103
104    let (mut tx, mut rx) = mpsc::channel(language_server_ids.len());
105    let added_subscription = cx.subscribe(project, {
106        let step_progress = step_progress.clone();
107        move |_, event, _| match event {
108            project::Event::LanguageServerAdded(language_server_id, name, _) => {
109                step_progress.set_substatus(format!("LSP started: {}", name));
110                tx.try_send(*language_server_id).ok();
111            }
112            _ => {}
113        }
114    });
115
116    while !starting_language_server_ids.is_empty() {
117        futures::select! {
118            language_server_id = rx.next() => {
119                if let Some(id) = language_server_id {
120                    starting_language_server_ids.remove(&id);
121                }
122            },
123            _ = timeout.clone().fuse() => {
124                return Err(anyhow::anyhow!("LSP wait timed out after 5 minutes"));
125            }
126        }
127    }
128
129    drop(added_subscription);
130
131    let (mut tx, mut rx) = mpsc::channel(language_server_ids.len());
132    let subscriptions = [
133        cx.subscribe(&lsp_store, {
134            let step_progress = step_progress.clone();
135            move |_, event, _| {
136                if let project::LspStoreEvent::LanguageServerUpdate {
137                    message:
138                        client::proto::update_language_server::Variant::WorkProgress(
139                            client::proto::LspWorkProgress {
140                                message: Some(message),
141                                ..
142                            },
143                        ),
144                    ..
145                } = event
146                {
147                    step_progress.set_substatus(message.clone());
148                }
149            }
150        }),
151        cx.subscribe(project, {
152            let step_progress = step_progress.clone();
153            let lsp_store = lsp_store.clone();
154            move |_, event, cx| match event {
155                project::Event::DiskBasedDiagnosticsFinished { language_server_id } => {
156                    let lsp_store = lsp_store.read(cx);
157                    let name = lsp_store
158                        .language_server_adapter_for_id(*language_server_id)
159                        .unwrap()
160                        .name();
161                    step_progress.set_substatus(format!("LSP idle: {}", name));
162                    tx.try_send(*language_server_id).ok();
163                }
164                _ => {}
165            }
166        }),
167    ];
168
169    project
170        .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
171        .await?;
172
173    let mut pending_language_server_ids = lsp_store.read_with(cx, |lsp_store, _| {
174        language_server_ids
175            .iter()
176            .copied()
177            .filter(|id| {
178                lsp_store
179                    .language_server_statuses
180                    .get(id)
181                    .is_some_and(|status| status.has_pending_diagnostic_updates)
182            })
183            .collect::<HashSet<_>>()
184    });
185    while !pending_language_server_ids.is_empty() {
186        futures::select! {
187            language_server_id = rx.next() => {
188                if let Some(id) = language_server_id {
189                    pending_language_server_ids.remove(&id);
190                }
191            },
192            _ = timeout.clone().fuse() => {
193                return Err(anyhow::anyhow!("LSP wait timed out after 5 minutes"));
194            }
195        }
196    }
197
198    drop(subscriptions);
199    step_progress.clear_substatus();
200    Ok(())
201}