retrieve_context.rs

  1use crate::{
  2    example::{Example, ExampleContext},
  3    headless::EpAppState,
  4    load_project::run_load_project,
  5    progress::{InfoStyle, Progress, Step, StepProgress},
  6};
  7use collections::HashSet;
  8use edit_prediction::{DebugEvent, EditPredictionStore};
  9use futures::{FutureExt as _, StreamExt as _, channel::mpsc};
 10use gpui::{AsyncApp, Entity};
 11use language::Buffer;
 12use project::Project;
 13use std::sync::Arc;
 14use std::time::Duration;
 15
 16pub async fn run_context_retrieval(
 17    example: &mut Example,
 18    app_state: Arc<EpAppState>,
 19    mut cx: AsyncApp,
 20) {
 21    if example.context.is_some() {
 22        return;
 23    }
 24
 25    run_load_project(example, app_state.clone(), cx.clone()).await;
 26
 27    let step_progress: Arc<StepProgress> = Progress::global()
 28        .start(Step::Context, &example.name)
 29        .into();
 30
 31    let state = example.state.as_ref().unwrap();
 32    let project = state.project.clone();
 33
 34    let _lsp_handle = project
 35        .update(&mut cx, |project, cx| {
 36            project.register_buffer_with_language_servers(&state.buffer, cx)
 37        })
 38        .unwrap();
 39    wait_for_language_servers_to_start(&project, &state.buffer, &step_progress, &mut cx).await;
 40
 41    let ep_store = cx
 42        .update(|cx| EditPredictionStore::try_global(cx).unwrap())
 43        .unwrap();
 44
 45    let mut events = ep_store
 46        .update(&mut cx, |store, cx| {
 47            store.register_buffer(&state.buffer, &project, cx);
 48            store.set_use_context(true);
 49            store.refresh_context(&project, &state.buffer, state.cursor_position, cx);
 50            store.debug_info(&project, cx)
 51        })
 52        .unwrap();
 53
 54    while let Some(event) = events.next().await {
 55        match event {
 56            DebugEvent::ContextRetrievalFinished(_) => {
 57                break;
 58            }
 59            _ => {}
 60        }
 61    }
 62
 63    let context_files = ep_store
 64        .update(&mut cx, |store, cx| store.context_for_project(&project, cx))
 65        .unwrap();
 66
 67    let excerpt_count: usize = context_files.iter().map(|f| f.excerpts.len()).sum();
 68    step_progress.set_info(format!("{} excerpts", excerpt_count), InfoStyle::Normal);
 69
 70    example.context = Some(ExampleContext {
 71        files: context_files,
 72    });
 73}
 74
 75async fn wait_for_language_servers_to_start(
 76    project: &Entity<Project>,
 77    buffer: &Entity<Buffer>,
 78    step_progress: &Arc<StepProgress>,
 79    cx: &mut AsyncApp,
 80) {
 81    let lsp_store = project
 82        .read_with(cx, |project, _| project.lsp_store())
 83        .unwrap();
 84
 85    let (language_server_ids, mut starting_language_server_ids) = buffer
 86        .update(cx, |buffer, cx| {
 87            lsp_store.update(cx, |lsp_store, cx| {
 88                let ids = lsp_store.language_servers_for_local_buffer(buffer, cx);
 89                let starting_ids = ids
 90                    .iter()
 91                    .copied()
 92                    .filter(|id| !lsp_store.language_server_statuses.contains_key(&id))
 93                    .collect::<HashSet<_>>();
 94                (ids, starting_ids)
 95            })
 96        })
 97        .unwrap_or_default();
 98
 99    step_progress.set_substatus(format!("waiting for {} LSPs", language_server_ids.len()));
100
101    let timeout = cx
102        .background_executor()
103        .timer(Duration::from_secs(60 * 5))
104        .shared();
105
106    let (mut tx, mut rx) = mpsc::channel(language_server_ids.len());
107    let added_subscription = cx.subscribe(project, {
108        let step_progress = step_progress.clone();
109        move |_, event, _| match event {
110            project::Event::LanguageServerAdded(language_server_id, name, _) => {
111                step_progress.set_substatus(format!("LSP started: {}", name));
112                tx.try_send(*language_server_id).ok();
113            }
114            _ => {}
115        }
116    });
117
118    while !starting_language_server_ids.is_empty() {
119        futures::select! {
120            language_server_id = rx.next() => {
121                if let Some(id) = language_server_id {
122                    starting_language_server_ids.remove(&id);
123                }
124            },
125            _ = timeout.clone().fuse() => {
126                panic!("LSP wait timed out after 5 minutes");
127            }
128        }
129    }
130
131    drop(added_subscription);
132
133    if !language_server_ids.is_empty() {
134        project
135            .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
136            .unwrap()
137            .detach();
138    }
139
140    let (mut tx, mut rx) = mpsc::channel(language_server_ids.len());
141    let subscriptions = [
142        cx.subscribe(&lsp_store, {
143            let step_progress = step_progress.clone();
144            move |_, event, _| {
145                if let project::LspStoreEvent::LanguageServerUpdate {
146                    message:
147                        client::proto::update_language_server::Variant::WorkProgress(
148                            client::proto::LspWorkProgress {
149                                message: Some(message),
150                                ..
151                            },
152                        ),
153                    ..
154                } = event
155                {
156                    step_progress.set_substatus(message.clone());
157                }
158            }
159        }),
160        cx.subscribe(project, {
161            let step_progress = step_progress.clone();
162            move |_, event, cx| match event {
163                project::Event::DiskBasedDiagnosticsFinished { language_server_id } => {
164                    let lsp_store = lsp_store.read(cx);
165                    let name = lsp_store
166                        .language_server_adapter_for_id(*language_server_id)
167                        .unwrap()
168                        .name();
169                    step_progress.set_substatus(format!("LSP idle: {}", name));
170                    tx.try_send(*language_server_id).ok();
171                }
172                _ => {}
173            }
174        }),
175    ];
176
177    project
178        .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
179        .unwrap()
180        .await
181        .unwrap();
182
183    let mut pending_language_server_ids = HashSet::from_iter(language_server_ids.into_iter());
184    while !pending_language_server_ids.is_empty() {
185        futures::select! {
186            language_server_id = rx.next() => {
187                if let Some(id) = language_server_id {
188                    pending_language_server_ids.remove(&id);
189                }
190            },
191            _ = timeout.clone().fuse() => {
192                panic!("LSP wait timed out after 5 minutes");
193            }
194        }
195    }
196
197    drop(subscriptions);
198    step_progress.clear_substatus();
199}