retrieve_context.rs

  1use crate::{
  2    example::Example,
  3    headless::EpAppState,
  4    load_project::run_load_project,
  5    progress::{ExampleProgress, InfoStyle, Step, StepProgress},
  6};
  7use anyhow::Context as _;
  8use collections::HashSet;
  9use edit_prediction::{DebugEvent, EditPredictionStore};
 10use futures::{FutureExt as _, StreamExt as _, channel::mpsc};
 11use gpui::{AsyncApp, Entity};
 12use language::Buffer;
 13use project::Project;
 14use std::sync::Arc;
 15use std::time::Duration;
 16
 17pub async fn run_context_retrieval(
 18    example: &mut Example,
 19    app_state: Arc<EpAppState>,
 20    example_progress: &ExampleProgress,
 21    mut cx: AsyncApp,
 22) -> anyhow::Result<()> {
 23    if example.prompt_inputs.is_some() {
 24        if example.spec.repository_url.is_empty() {
 25            return Ok(());
 26        }
 27
 28        if example
 29            .prompt_inputs
 30            .as_ref()
 31            .is_some_and(|inputs| !inputs.related_files.is_empty())
 32        {
 33            return Ok(());
 34        }
 35    }
 36
 37    run_load_project(example, app_state.clone(), example_progress, cx.clone()).await?;
 38
 39    let step_progress: Arc<StepProgress> = example_progress.start(Step::Context).into();
 40
 41    let state = example.state.as_ref().unwrap();
 42    let project = state.project.clone();
 43
 44    let _lsp_handle = project.update(&mut cx, |project, cx| {
 45        project.register_buffer_with_language_servers(&state.buffer, cx)
 46    });
 47    wait_for_language_servers_to_start(&project, &state.buffer, &step_progress, &mut cx).await?;
 48
 49    let ep_store = cx
 50        .update(|cx| EditPredictionStore::try_global(cx))
 51        .context("EditPredictionStore not initialized")?;
 52
 53    let mut events = ep_store.update(&mut cx, |store, cx| {
 54        store.register_buffer(&state.buffer, &project, cx);
 55        store.refresh_context(&project, &state.buffer, state.cursor_position, cx);
 56        store.debug_info(&project, cx)
 57    });
 58
 59    while let Some(event) = events.next().await {
 60        match event {
 61            DebugEvent::ContextRetrievalFinished(_) => {
 62                break;
 63            }
 64            _ => {}
 65        }
 66    }
 67
 68    let context_files =
 69        ep_store.update(&mut cx, |store, cx| store.context_for_project(&project, cx));
 70
 71    let excerpt_count: usize = context_files.iter().map(|f| f.excerpts.len()).sum();
 72    step_progress.set_info(format!("{} excerpts", excerpt_count), InfoStyle::Normal);
 73
 74    if let Some(prompt_inputs) = example.prompt_inputs.as_mut() {
 75        prompt_inputs.related_files = context_files;
 76    }
 77    Ok(())
 78}
 79
 80async fn wait_for_language_servers_to_start(
 81    project: &Entity<Project>,
 82    buffer: &Entity<Buffer>,
 83    step_progress: &Arc<StepProgress>,
 84    cx: &mut AsyncApp,
 85) -> anyhow::Result<()> {
 86    let lsp_store = project.read_with(cx, |project, _| project.lsp_store());
 87
 88    let (language_server_ids, mut starting_language_server_ids) =
 89        buffer.update(cx, |buffer, cx| {
 90            lsp_store.update(cx, |lsp_store, cx| {
 91                let ids = lsp_store.language_servers_for_local_buffer(buffer, cx);
 92                let starting_ids = ids
 93                    .iter()
 94                    .copied()
 95                    .filter(|id| !lsp_store.language_server_statuses.contains_key(&id))
 96                    .collect::<HashSet<_>>();
 97                (ids, starting_ids)
 98            })
 99        });
100
101    step_progress.set_substatus(format!("waiting for {} LSPs", language_server_ids.len()));
102
103    let timeout_duration = if starting_language_server_ids.is_empty() {
104        Duration::from_secs(30)
105    } else {
106        Duration::from_secs(60 * 5)
107    };
108
109    let timeout = cx.background_executor().timer(timeout_duration).shared();
110
111    let (mut tx, mut rx) = mpsc::channel(language_server_ids.len());
112    let added_subscription = cx.subscribe(project, {
113        let step_progress = step_progress.clone();
114        move |_, event, _| match event {
115            project::Event::LanguageServerAdded(language_server_id, name, _) => {
116                step_progress.set_substatus(format!("LSP started: {}", name));
117                tx.try_send(*language_server_id).ok();
118            }
119            _ => {}
120        }
121    });
122
123    while !starting_language_server_ids.is_empty() {
124        futures::select! {
125            language_server_id = rx.next() => {
126                if let Some(id) = language_server_id {
127                    starting_language_server_ids.remove(&id);
128                }
129            },
130            _ = timeout.clone().fuse() => {
131                return Err(anyhow::anyhow!("LSP wait timed out after {} minutes", timeout_duration.as_secs() / 60));
132            }
133        }
134    }
135
136    drop(added_subscription);
137
138    let (mut tx, mut rx) = mpsc::channel(language_server_ids.len());
139    let subscriptions = [
140        cx.subscribe(&lsp_store, {
141            let step_progress = step_progress.clone();
142            move |_, event, _| {
143                if let project::LspStoreEvent::LanguageServerUpdate {
144                    message:
145                        client::proto::update_language_server::Variant::WorkProgress(
146                            client::proto::LspWorkProgress {
147                                message: Some(message),
148                                ..
149                            },
150                        ),
151                    ..
152                } = event
153                {
154                    step_progress.set_substatus(message.clone());
155                }
156            }
157        }),
158        cx.subscribe(project, {
159            let step_progress = step_progress.clone();
160            let lsp_store = lsp_store.clone();
161            move |_, event, cx| match event {
162                project::Event::DiskBasedDiagnosticsFinished { language_server_id } => {
163                    let lsp_store = lsp_store.read(cx);
164                    let name = lsp_store
165                        .language_server_adapter_for_id(*language_server_id)
166                        .unwrap()
167                        .name();
168                    step_progress.set_substatus(format!("LSP idle: {}", name));
169                    tx.try_send(*language_server_id).ok();
170                }
171                _ => {}
172            }
173        }),
174    ];
175
176    project
177        .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
178        .await?;
179
180    let mut pending_language_server_ids = lsp_store.read_with(cx, |lsp_store, _| {
181        language_server_ids
182            .iter()
183            .copied()
184            .filter(|id| {
185                lsp_store
186                    .language_server_statuses
187                    .get(id)
188                    .is_some_and(|status| status.has_pending_diagnostic_updates)
189            })
190            .collect::<HashSet<_>>()
191    });
192    while !pending_language_server_ids.is_empty() {
193        futures::select! {
194            language_server_id = rx.next() => {
195                if let Some(id) = language_server_id {
196                    pending_language_server_ids.remove(&id);
197                }
198            },
199            _ = timeout.clone().fuse() => {
200                return Err(anyhow::anyhow!("LSP wait timed out after {} minutes", timeout_duration.as_secs() / 60));
201            }
202        }
203    }
204
205    drop(subscriptions);
206    step_progress.clear_substatus();
207    Ok(())
208}