// retrieve_context.rs

  1use crate::{
  2    example::{Example, ExampleContext},
  3    headless::EpAppState,
  4    load_project::run_load_project,
  5};
  6use collections::HashSet;
  7use edit_prediction::{DebugEvent, EditPredictionStore};
  8use futures::{FutureExt as _, StreamExt as _, channel::mpsc};
  9use gpui::{AsyncApp, Entity};
 10use language::Buffer;
 11use project::Project;
 12use std::{sync::Arc, time::Duration};
 13
 14pub async fn run_context_retrieval(
 15    example: &mut Example,
 16    app_state: Arc<EpAppState>,
 17    mut cx: AsyncApp,
 18) {
 19    if example.context.is_some() {
 20        return;
 21    }
 22
 23    run_load_project(example, app_state.clone(), cx.clone()).await;
 24
 25    let state = example.state.as_ref().unwrap();
 26    let project = state.project.clone();
 27
 28    let _lsp_handle = project
 29        .update(&mut cx, |project, cx| {
 30            project.register_buffer_with_language_servers(&state.buffer, cx)
 31        })
 32        .unwrap();
 33    wait_for_language_servers_to_start(example, &project, &state.buffer, &mut cx).await;
 34
 35    let ep_store = cx
 36        .update(|cx| EditPredictionStore::try_global(cx).unwrap())
 37        .unwrap();
 38
 39    let mut events = ep_store
 40        .update(&mut cx, |store, cx| {
 41            store.register_buffer(&state.buffer, &project, cx);
 42            store.set_use_context(true);
 43            store.refresh_context(&project, &state.buffer, state.cursor_position, cx);
 44            store.debug_info(&project, cx)
 45        })
 46        .unwrap();
 47
 48    while let Some(event) = events.next().await {
 49        match event {
 50            DebugEvent::ContextRetrievalFinished(_) => {
 51                break;
 52            }
 53            _ => {}
 54        }
 55    }
 56
 57    let context_files = ep_store
 58        .update(&mut cx, |store, cx| store.context_for_project(&project, cx))
 59        .unwrap();
 60
 61    example.context = Some(ExampleContext {
 62        files: context_files,
 63    });
 64}
 65
 66async fn wait_for_language_servers_to_start(
 67    example: &Example,
 68    project: &Entity<Project>,
 69    buffer: &Entity<Buffer>,
 70    cx: &mut AsyncApp,
 71) {
 72    let log_prefix = format!("{} | ", example.name);
 73
 74    let lsp_store = project
 75        .read_with(cx, |project, _| project.lsp_store())
 76        .unwrap();
 77
 78    let (language_server_ids, mut starting_language_server_ids) = buffer
 79        .update(cx, |buffer, cx| {
 80            lsp_store.update(cx, |lsp_store, cx| {
 81                let ids = lsp_store.language_servers_for_local_buffer(buffer, cx);
 82                let starting_ids = ids
 83                    .iter()
 84                    .copied()
 85                    .filter(|id| !lsp_store.language_server_statuses.contains_key(&id))
 86                    .collect::<HashSet<_>>();
 87                (ids, starting_ids)
 88            })
 89        })
 90        .unwrap_or_default();
 91
 92    eprintln!(
 93        "{}⏵ Waiting for {} language servers",
 94        log_prefix,
 95        language_server_ids.len()
 96    );
 97
 98    let timeout = cx
 99        .background_executor()
100        .timer(Duration::from_secs(60 * 5))
101        .shared();
102
103    let (mut tx, mut rx) = mpsc::channel(language_server_ids.len());
104    let added_subscription = cx.subscribe(project, {
105        let log_prefix = log_prefix.clone();
106        move |_, event, _| match event {
107            project::Event::LanguageServerAdded(language_server_id, name, _) => {
108                eprintln!("{}+ Language server started: {}", log_prefix, name);
109                tx.try_send(*language_server_id).ok();
110            }
111            _ => {}
112        }
113    });
114
115    while !starting_language_server_ids.is_empty() {
116        futures::select! {
117            language_server_id = rx.next() => {
118                if let Some(id) = language_server_id {
119                    starting_language_server_ids.remove(&id);
120                }
121            },
122            _ = timeout.clone().fuse() => {
123                panic!("LSP wait timed out after 5 minutes");
124            }
125        }
126    }
127
128    drop(added_subscription);
129
130    if !language_server_ids.is_empty() {
131        project
132            .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
133            .unwrap()
134            .detach();
135    }
136
137    let (mut tx, mut rx) = mpsc::channel(language_server_ids.len());
138    let subscriptions = [
139        cx.subscribe(&lsp_store, {
140            let log_prefix = log_prefix.clone();
141            move |_, event, _| {
142                if let project::LspStoreEvent::LanguageServerUpdate {
143                    message:
144                        client::proto::update_language_server::Variant::WorkProgress(
145                            client::proto::LspWorkProgress {
146                                message: Some(message),
147                                ..
148                            },
149                        ),
150                    ..
151                } = event
152                {
153                    eprintln!("{}{message}", log_prefix)
154                }
155            }
156        }),
157        cx.subscribe(project, {
158            let log_prefix = log_prefix.clone();
159            move |_, event, cx| match event {
160                project::Event::DiskBasedDiagnosticsFinished { language_server_id } => {
161                    let lsp_store = lsp_store.read(cx);
162                    let name = lsp_store
163                        .language_server_adapter_for_id(*language_server_id)
164                        .unwrap()
165                        .name();
166                    eprintln!("{}⚑ Language server idle: {}", log_prefix, name);
167                    tx.try_send(*language_server_id).ok();
168                }
169                _ => {}
170            }
171        }),
172    ];
173
174    project
175        .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
176        .unwrap()
177        .await
178        .unwrap();
179
180    let mut pending_language_server_ids = HashSet::from_iter(language_server_ids.into_iter());
181    while !pending_language_server_ids.is_empty() {
182        futures::select! {
183            language_server_id = rx.next() => {
184                if let Some(id) = language_server_id {
185                    pending_language_server_ids.remove(&id);
186                }
187            },
188            _ = timeout.clone().fuse() => {
189                panic!("LSP wait timed out after 5 minutes");
190            }
191        }
192    }
193
194    drop(subscriptions);
195}