retrieve_context.rs

  1use crate::{
  2    example::{Example, ExampleContext},
  3    headless::EpAppState,
  4    load_project::run_load_project,
  5};
  6use collections::HashSet;
  7use edit_prediction::{DebugEvent, EditPredictionStore};
  8use futures::{FutureExt as _, StreamExt as _, channel::mpsc};
  9use gpui::{AsyncApp, Entity};
 10use language::Buffer;
 11use project::Project;
 12use std::{sync::Arc, time::Duration};
 13
 14pub async fn run_context_retrieval(
 15    example: &mut Example,
 16    app_state: Arc<EpAppState>,
 17    mut cx: AsyncApp,
 18) {
 19    if example.context.is_some() {
 20        return;
 21    }
 22
 23    run_load_project(example, app_state.clone(), cx.clone()).await;
 24
 25    let state = example.state.as_ref().unwrap();
 26    let project = state.project.clone();
 27
 28    let _lsp_handle = project
 29        .update(&mut cx, |project, cx| {
 30            project.register_buffer_with_language_servers(&state.buffer, cx)
 31        })
 32        .unwrap();
 33    wait_for_language_servers_to_start(example, &project, &state.buffer, &mut cx).await;
 34
 35    let ep_store = cx
 36        .update(|cx| EditPredictionStore::try_global(cx).unwrap())
 37        .unwrap();
 38
 39    let mut events = ep_store
 40        .update(&mut cx, |store, cx| {
 41            store.register_buffer(&state.buffer, &project, cx);
 42            store.set_use_context(true);
 43            store.refresh_context(&project, &state.buffer, state.cursor_position, cx);
 44            store.debug_info(&project, cx)
 45        })
 46        .unwrap();
 47
 48    while let Some(event) = events.next().await {
 49        match event {
 50            DebugEvent::ContextRetrievalFinished(_) => {
 51                break;
 52            }
 53            _ => {}
 54        }
 55    }
 56
 57    let context_files = ep_store
 58        .update(&mut cx, |store, cx| store.context_for_project(&project, cx))
 59        .unwrap();
 60
 61    example.context = Some(ExampleContext {
 62        files: context_files,
 63    });
 64}
 65
 66async fn wait_for_language_servers_to_start(
 67    example: &Example,
 68    project: &Entity<Project>,
 69    buffer: &Entity<Buffer>,
 70    cx: &mut AsyncApp,
 71) {
 72    let log_prefix = format!("{} | ", example.name);
 73
 74    let lsp_store = project
 75        .read_with(cx, |project, _| project.lsp_store())
 76        .unwrap();
 77
 78    let lang_server_ids = buffer
 79        .update(cx, |buffer, cx| {
 80            lsp_store.update(cx, |lsp_store, cx| {
 81                lsp_store.language_servers_for_local_buffer(buffer, cx)
 82            })
 83        })
 84        .unwrap_or_default();
 85
 86    if !lang_server_ids.is_empty() {
 87        project
 88            .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
 89            .unwrap()
 90            .detach();
 91    }
 92
 93    eprintln!(
 94        "{}⏵ Waiting for {} language servers",
 95        log_prefix,
 96        lang_server_ids.len()
 97    );
 98
 99    let timeout = cx
100        .background_executor()
101        .timer(Duration::from_secs(60 * 5))
102        .shared();
103
104    let (mut tx, mut rx) = mpsc::channel(lang_server_ids.len());
105    let added_subscription = cx.subscribe(project, {
106        let log_prefix = log_prefix.clone();
107        move |_, event, _| match event {
108            project::Event::LanguageServerAdded(language_server_id, name, _) => {
109                eprintln!("{}+ Language server started: {}", log_prefix, name);
110                tx.try_send(*language_server_id).ok();
111            }
112            _ => {}
113        }
114    });
115
116    let mut pending_language_server_ids = HashSet::from_iter(lang_server_ids.iter());
117    while !pending_language_server_ids.is_empty() {
118        futures::select! {
119            language_server_id = rx.next() => {
120                if let Some(id) = language_server_id {
121                    pending_language_server_ids.remove(&id);
122                }
123            },
124            _ = timeout.clone().fuse() => {
125                panic!("LSP wait timed out after 5 minutes");
126            }
127        }
128    }
129
130    drop(added_subscription);
131
132    let (mut tx, mut rx) = mpsc::channel(lang_server_ids.len());
133    let subscriptions = [
134        cx.subscribe(&lsp_store, {
135            let log_prefix = log_prefix.clone();
136            move |_, event, _| {
137                if let project::LspStoreEvent::LanguageServerUpdate {
138                    message:
139                        client::proto::update_language_server::Variant::WorkProgress(
140                            client::proto::LspWorkProgress {
141                                message: Some(message),
142                                ..
143                            },
144                        ),
145                    ..
146                } = event
147                {
148                    eprintln!("{}{message}", log_prefix)
149                }
150            }
151        }),
152        cx.subscribe(project, {
153            let log_prefix = log_prefix.clone();
154            move |_, event, cx| match event {
155                project::Event::DiskBasedDiagnosticsFinished { language_server_id } => {
156                    let lsp_store = lsp_store.read(cx);
157                    let name = lsp_store
158                        .language_server_adapter_for_id(*language_server_id)
159                        .unwrap()
160                        .name();
161                    eprintln!("{}⚑ Language server idle: {}", log_prefix, name);
162                    tx.try_send(*language_server_id).ok();
163                }
164                _ => {}
165            }
166        }),
167    ];
168
169    project
170        .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
171        .unwrap()
172        .await
173        .unwrap();
174
175    let mut pending_language_server_ids = HashSet::from_iter(lang_server_ids.into_iter());
176    while !pending_language_server_ids.is_empty() {
177        futures::select! {
178            language_server_id = rx.next() => {
179                if let Some(id) = language_server_id {
180                    pending_language_server_ids.remove(&id);
181                }
182            },
183            _ = timeout.clone().fuse() => {
184                panic!("LSP wait timed out after 5 minutes");
185            }
186        }
187    }
188
189    drop(subscriptions);
190}