retrieve_context.rs

  1use crate::{
  2    example::Example,
  3    headless::EpAppState,
  4    load_project::run_load_project,
  5    progress::{ExampleProgress, InfoStyle, Step, StepProgress},
  6};
  7use anyhow::Context as _;
  8use collections::HashSet;
  9use edit_prediction::{DebugEvent, EditPredictionStore};
 10use futures::{FutureExt as _, StreamExt as _, channel::mpsc};
 11use gpui::{AsyncApp, Entity};
 12use language::Buffer;
 13use project::Project;
 14use std::sync::Arc;
 15use std::time::Duration;
 16
 17pub async fn run_context_retrieval(
 18    example: &mut Example,
 19    app_state: Arc<EpAppState>,
 20    example_progress: &ExampleProgress,
 21    mut cx: AsyncApp,
 22) -> anyhow::Result<()> {
 23    if example
 24        .prompt_inputs
 25        .as_ref()
 26        .is_some_and(|inputs| inputs.related_files.is_some())
 27    {
 28        return Ok(());
 29    }
 30
 31    run_load_project(example, app_state.clone(), example_progress, cx.clone()).await?;
 32
 33    let step_progress: Arc<StepProgress> = example_progress.start(Step::Context).into();
 34
 35    let state = example.state.as_ref().unwrap();
 36    let project = state.project.clone();
 37
 38    let _lsp_handle = project.update(&mut cx, |project, cx| {
 39        project.register_buffer_with_language_servers(&state.buffer, cx)
 40    });
 41    wait_for_language_servers_to_start(&project, &state.buffer, &step_progress, &mut cx).await?;
 42
 43    let ep_store = cx
 44        .update(|cx| EditPredictionStore::try_global(cx))
 45        .context("EditPredictionStore not initialized")?;
 46
 47    let mut events = ep_store.update(&mut cx, |store, cx| {
 48        store.register_buffer(&state.buffer, &project, cx);
 49        store.set_use_context(true);
 50        store.refresh_context(&project, &state.buffer, state.cursor_position, cx);
 51        store.debug_info(&project, cx)
 52    });
 53
 54    while let Some(event) = events.next().await {
 55        match event {
 56            DebugEvent::ContextRetrievalFinished(_) => {
 57                break;
 58            }
 59            _ => {}
 60        }
 61    }
 62
 63    let context_files =
 64        ep_store.update(&mut cx, |store, cx| store.context_for_project(&project, cx));
 65
 66    let excerpt_count: usize = context_files.iter().map(|f| f.excerpts.len()).sum();
 67    step_progress.set_info(format!("{} excerpts", excerpt_count), InfoStyle::Normal);
 68
 69    if let Some(prompt_inputs) = example.prompt_inputs.as_mut() {
 70        prompt_inputs.related_files = Some(context_files);
 71    }
 72    Ok(())
 73}
 74
 75async fn wait_for_language_servers_to_start(
 76    project: &Entity<Project>,
 77    buffer: &Entity<Buffer>,
 78    step_progress: &Arc<StepProgress>,
 79    cx: &mut AsyncApp,
 80) -> anyhow::Result<()> {
 81    let lsp_store = project.read_with(cx, |project, _| project.lsp_store());
 82
 83    let (language_server_ids, mut starting_language_server_ids) =
 84        buffer.update(cx, |buffer, cx| {
 85            lsp_store.update(cx, |lsp_store, cx| {
 86                let ids = lsp_store.language_servers_for_local_buffer(buffer, cx);
 87                let starting_ids = ids
 88                    .iter()
 89                    .copied()
 90                    .filter(|id| !lsp_store.language_server_statuses.contains_key(&id))
 91                    .collect::<HashSet<_>>();
 92                (ids, starting_ids)
 93            })
 94        });
 95
 96    step_progress.set_substatus(format!("waiting for {} LSPs", language_server_ids.len()));
 97
 98    let timeout_duration = if starting_language_server_ids.is_empty() {
 99        Duration::from_secs(30)
100    } else {
101        Duration::from_secs(60 * 5)
102    };
103
104    let timeout = cx.background_executor().timer(timeout_duration).shared();
105
106    let (mut tx, mut rx) = mpsc::channel(language_server_ids.len());
107    let added_subscription = cx.subscribe(project, {
108        let step_progress = step_progress.clone();
109        move |_, event, _| match event {
110            project::Event::LanguageServerAdded(language_server_id, name, _) => {
111                step_progress.set_substatus(format!("LSP started: {}", name));
112                tx.try_send(*language_server_id).ok();
113            }
114            _ => {}
115        }
116    });
117
118    while !starting_language_server_ids.is_empty() {
119        futures::select! {
120            language_server_id = rx.next() => {
121                if let Some(id) = language_server_id {
122                    starting_language_server_ids.remove(&id);
123                }
124            },
125            _ = timeout.clone().fuse() => {
126                return Err(anyhow::anyhow!("LSP wait timed out after {} minutes", timeout_duration.as_secs() / 60));
127            }
128        }
129    }
130
131    drop(added_subscription);
132
133    let (mut tx, mut rx) = mpsc::channel(language_server_ids.len());
134    let subscriptions = [
135        cx.subscribe(&lsp_store, {
136            let step_progress = step_progress.clone();
137            move |_, event, _| {
138                if let project::LspStoreEvent::LanguageServerUpdate {
139                    message:
140                        client::proto::update_language_server::Variant::WorkProgress(
141                            client::proto::LspWorkProgress {
142                                message: Some(message),
143                                ..
144                            },
145                        ),
146                    ..
147                } = event
148                {
149                    step_progress.set_substatus(message.clone());
150                }
151            }
152        }),
153        cx.subscribe(project, {
154            let step_progress = step_progress.clone();
155            let lsp_store = lsp_store.clone();
156            move |_, event, cx| match event {
157                project::Event::DiskBasedDiagnosticsFinished { language_server_id } => {
158                    let lsp_store = lsp_store.read(cx);
159                    let name = lsp_store
160                        .language_server_adapter_for_id(*language_server_id)
161                        .unwrap()
162                        .name();
163                    step_progress.set_substatus(format!("LSP idle: {}", name));
164                    tx.try_send(*language_server_id).ok();
165                }
166                _ => {}
167            }
168        }),
169    ];
170
171    project
172        .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
173        .await?;
174
175    let mut pending_language_server_ids = lsp_store.read_with(cx, |lsp_store, _| {
176        language_server_ids
177            .iter()
178            .copied()
179            .filter(|id| {
180                lsp_store
181                    .language_server_statuses
182                    .get(id)
183                    .is_some_and(|status| status.has_pending_diagnostic_updates)
184            })
185            .collect::<HashSet<_>>()
186    });
187    while !pending_language_server_ids.is_empty() {
188        futures::select! {
189            language_server_id = rx.next() => {
190                if let Some(id) = language_server_id {
191                    pending_language_server_ids.remove(&id);
192                }
193            },
194            _ = timeout.clone().fuse() => {
195                return Err(anyhow::anyhow!("LSP wait timed out after {} minutes", timeout_duration.as_secs() / 60));
196            }
197        }
198    }
199
200    drop(subscriptions);
201    step_progress.clear_substatus();
202    Ok(())
203}