retrieve_context.rs

  1use crate::{
  2    example::{Example, ExampleContext},
  3    headless::EpAppState,
  4    load_project::run_load_project,
  5    progress::{InfoStyle, Progress, Step, StepProgress},
  6};
  7use collections::HashSet;
  8use edit_prediction::{DebugEvent, EditPredictionStore};
  9use futures::{FutureExt as _, StreamExt as _, channel::mpsc};
 10use gpui::{AsyncApp, Entity};
 11use language::Buffer;
 12use project::Project;
 13use std::sync::Arc;
 14use std::time::Duration;
 15
 16pub async fn run_context_retrieval(
 17    example: &mut Example,
 18    app_state: Arc<EpAppState>,
 19    progress: Arc<Progress>,
 20    mut cx: AsyncApp,
 21) {
 22    if example.context.is_some() {
 23        return;
 24    }
 25
 26    run_load_project(example, app_state.clone(), progress.clone(), cx.clone()).await;
 27
 28    let step_progress = progress.start(Step::Context, &example.name);
 29
 30    let state = example.state.as_ref().unwrap();
 31    let project = state.project.clone();
 32
 33    let _lsp_handle = project
 34        .update(&mut cx, |project, cx| {
 35            project.register_buffer_with_language_servers(&state.buffer, cx)
 36        })
 37        .unwrap();
 38    wait_for_language_servers_to_start(&project, &state.buffer, &step_progress, &mut cx).await;
 39
 40    let ep_store = cx
 41        .update(|cx| EditPredictionStore::try_global(cx).unwrap())
 42        .unwrap();
 43
 44    let mut events = ep_store
 45        .update(&mut cx, |store, cx| {
 46            store.register_buffer(&state.buffer, &project, cx);
 47            store.set_use_context(true);
 48            store.refresh_context(&project, &state.buffer, state.cursor_position, cx);
 49            store.debug_info(&project, cx)
 50        })
 51        .unwrap();
 52
 53    while let Some(event) = events.next().await {
 54        match event {
 55            DebugEvent::ContextRetrievalFinished(_) => {
 56                break;
 57            }
 58            _ => {}
 59        }
 60    }
 61
 62    let context_files = ep_store
 63        .update(&mut cx, |store, cx| store.context_for_project(&project, cx))
 64        .unwrap();
 65
 66    let excerpt_count: usize = context_files.iter().map(|f| f.excerpts.len()).sum();
 67    step_progress.set_info(format!("{} excerpts", excerpt_count), InfoStyle::Normal);
 68
 69    example.context = Some(ExampleContext {
 70        files: context_files,
 71    });
 72}
 73
 74async fn wait_for_language_servers_to_start(
 75    project: &Entity<Project>,
 76    buffer: &Entity<Buffer>,
 77    step_progress: &Arc<StepProgress>,
 78    cx: &mut AsyncApp,
 79) {
 80    let lsp_store = project
 81        .read_with(cx, |project, _| project.lsp_store())
 82        .unwrap();
 83
 84    let (language_server_ids, mut starting_language_server_ids) = buffer
 85        .update(cx, |buffer, cx| {
 86            lsp_store.update(cx, |lsp_store, cx| {
 87                let ids = lsp_store.language_servers_for_local_buffer(buffer, cx);
 88                let starting_ids = ids
 89                    .iter()
 90                    .copied()
 91                    .filter(|id| !lsp_store.language_server_statuses.contains_key(&id))
 92                    .collect::<HashSet<_>>();
 93                (ids, starting_ids)
 94            })
 95        })
 96        .unwrap_or_default();
 97
 98    step_progress.set_substatus(format!("waiting for {} LSPs", language_server_ids.len()));
 99
100    let timeout = cx
101        .background_executor()
102        .timer(Duration::from_secs(60 * 5))
103        .shared();
104
105    let (mut tx, mut rx) = mpsc::channel(language_server_ids.len());
106    let added_subscription = cx.subscribe(project, {
107        let step_progress = step_progress.clone();
108        move |_, event, _| match event {
109            project::Event::LanguageServerAdded(language_server_id, name, _) => {
110                step_progress.set_substatus(format!("LSP started: {}", name));
111                tx.try_send(*language_server_id).ok();
112            }
113            _ => {}
114        }
115    });
116
117    while !starting_language_server_ids.is_empty() {
118        futures::select! {
119            language_server_id = rx.next() => {
120                if let Some(id) = language_server_id {
121                    starting_language_server_ids.remove(&id);
122                }
123            },
124            _ = timeout.clone().fuse() => {
125                panic!("LSP wait timed out after 5 minutes");
126            }
127        }
128    }
129
130    drop(added_subscription);
131
132    if !language_server_ids.is_empty() {
133        project
134            .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
135            .unwrap()
136            .detach();
137    }
138
139    let (mut tx, mut rx) = mpsc::channel(language_server_ids.len());
140    let subscriptions = [
141        cx.subscribe(&lsp_store, {
142            let step_progress = step_progress.clone();
143            move |_, event, _| {
144                if let project::LspStoreEvent::LanguageServerUpdate {
145                    message:
146                        client::proto::update_language_server::Variant::WorkProgress(
147                            client::proto::LspWorkProgress {
148                                message: Some(message),
149                                ..
150                            },
151                        ),
152                    ..
153                } = event
154                {
155                    step_progress.set_substatus(message.clone());
156                }
157            }
158        }),
159        cx.subscribe(project, {
160            let step_progress = step_progress.clone();
161            move |_, event, cx| match event {
162                project::Event::DiskBasedDiagnosticsFinished { language_server_id } => {
163                    let lsp_store = lsp_store.read(cx);
164                    let name = lsp_store
165                        .language_server_adapter_for_id(*language_server_id)
166                        .unwrap()
167                        .name();
168                    step_progress.set_substatus(format!("LSP idle: {}", name));
169                    tx.try_send(*language_server_id).ok();
170                }
171                _ => {}
172            }
173        }),
174    ];
175
176    project
177        .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
178        .unwrap()
179        .await
180        .unwrap();
181
182    let mut pending_language_server_ids = HashSet::from_iter(language_server_ids.into_iter());
183    while !pending_language_server_ids.is_empty() {
184        futures::select! {
185            language_server_id = rx.next() => {
186                if let Some(id) = language_server_id {
187                    pending_language_server_ids.remove(&id);
188                }
189            },
190            _ = timeout.clone().fuse() => {
191                panic!("LSP wait timed out after 5 minutes");
192            }
193        }
194    }
195
196    drop(subscriptions);
197    step_progress.clear_substatus();
198}