retrieve_context.rs

  1use crate::{
  2    example::{Example, ExampleContext},
  3    headless::EpAppState,
  4    load_project::run_load_project,
  5};
  6use anyhow::Result;
  7use collections::HashSet;
  8use edit_prediction::{DebugEvent, EditPredictionStore};
  9use futures::{FutureExt as _, StreamExt as _, channel::mpsc};
 10use gpui::{AsyncApp, Entity, Task};
 11use language::Buffer;
 12use project::Project;
 13use std::{sync::Arc, time::Duration};
 14
 15pub async fn run_context_retrieval(
 16    example: &mut Example,
 17    app_state: Arc<EpAppState>,
 18    mut cx: AsyncApp,
 19) {
 20    if example.context.is_some() {
 21        return;
 22    }
 23
 24    run_load_project(example, app_state.clone(), cx.clone()).await;
 25
 26    let state = example.state.as_ref().unwrap();
 27    let project = state.project.clone();
 28
 29    let _lsp_handle = project
 30        .update(&mut cx, |project, cx| {
 31            project.register_buffer_with_language_servers(&state.buffer, cx)
 32        })
 33        .unwrap();
 34
 35    wait_for_language_server_to_start(example, &project, &state.buffer, &mut cx).await;
 36
 37    let ep_store = cx
 38        .update(|cx| EditPredictionStore::try_global(cx).unwrap())
 39        .unwrap();
 40
 41    let mut events = ep_store
 42        .update(&mut cx, |store, cx| {
 43            store.register_buffer(&state.buffer, &project, cx);
 44            store.set_use_context(true);
 45            store.refresh_context(&project, &state.buffer, state.cursor_position, cx);
 46            store.debug_info(&project, cx)
 47        })
 48        .unwrap();
 49
 50    while let Some(event) = events.next().await {
 51        match event {
 52            DebugEvent::ContextRetrievalFinished(_) => {
 53                break;
 54            }
 55            _ => {}
 56        }
 57    }
 58
 59    let context_files = ep_store
 60        .update(&mut cx, |store, cx| store.context_for_project(&project, cx))
 61        .unwrap();
 62
 63    example.context = Some(ExampleContext {
 64        files: context_files,
 65    });
 66}
 67
 68async fn wait_for_language_server_to_start(
 69    example: &Example,
 70    project: &Entity<Project>,
 71    buffer: &Entity<Buffer>,
 72    cx: &mut AsyncApp,
 73) {
 74    let Some(language_id) = buffer
 75        .read_with(cx, |buffer, _cx| {
 76            buffer.language().map(|language| language.id())
 77        })
 78        .unwrap()
 79    else {
 80        panic!("No language for {:?}", example.cursor_path);
 81    };
 82
 83    let mut ready_languages = HashSet::default();
 84    let log_prefix = format!("{} | ", example.name);
 85    if !ready_languages.contains(&language_id) {
 86        wait_for_lang_server(&project, &buffer, log_prefix, cx)
 87            .await
 88            .unwrap();
 89        ready_languages.insert(language_id);
 90    }
 91
 92    let lsp_store = project
 93        .read_with(cx, |project, _cx| project.lsp_store())
 94        .unwrap();
 95
 96    // hacky wait for buffer to be registered with the language server
 97    for _ in 0..100 {
 98        if lsp_store
 99            .update(cx, |lsp_store, cx| {
100                buffer.update(cx, |buffer, cx| {
101                    lsp_store
102                        .language_servers_for_local_buffer(&buffer, cx)
103                        .next()
104                        .map(|(_, language_server)| language_server.server_id())
105                })
106            })
107            .unwrap()
108            .is_some()
109        {
110            return;
111        } else {
112            cx.background_executor()
113                .timer(Duration::from_millis(10))
114                .await;
115        }
116    }
117
118    panic!("No language server found for buffer");
119}
120
121pub fn wait_for_lang_server(
122    project: &Entity<Project>,
123    buffer: &Entity<Buffer>,
124    log_prefix: String,
125    cx: &mut AsyncApp,
126) -> Task<Result<()>> {
127    eprintln!("{}⏵ Waiting for language server", log_prefix);
128
129    let (mut tx, mut rx) = mpsc::channel(1);
130
131    let lsp_store = project
132        .read_with(cx, |project, _| project.lsp_store())
133        .unwrap();
134
135    let has_lang_server = buffer
136        .update(cx, |buffer, cx| {
137            lsp_store.update(cx, |lsp_store, cx| {
138                lsp_store
139                    .language_servers_for_local_buffer(buffer, cx)
140                    .next()
141                    .is_some()
142            })
143        })
144        .unwrap_or(false);
145
146    if has_lang_server {
147        project
148            .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
149            .unwrap()
150            .detach();
151    }
152    let (mut added_tx, mut added_rx) = mpsc::channel(1);
153
154    let subscriptions = [
155        cx.subscribe(&lsp_store, {
156            let log_prefix = log_prefix.clone();
157            move |_, event, _| {
158                if let project::LspStoreEvent::LanguageServerUpdate {
159                    message:
160                        client::proto::update_language_server::Variant::WorkProgress(
161                            client::proto::LspWorkProgress {
162                                message: Some(message),
163                                ..
164                            },
165                        ),
166                    ..
167                } = event
168                {
169                    eprintln!("{}{message}", log_prefix)
170                }
171            }
172        }),
173        cx.subscribe(project, {
174            let buffer = buffer.clone();
175            move |project, event, cx| match event {
176                project::Event::LanguageServerAdded(_, _, _) => {
177                    let buffer = buffer.clone();
178                    project
179                        .update(cx, |project, cx| project.save_buffer(buffer, cx))
180                        .detach();
181                    added_tx.try_send(()).ok();
182                }
183                project::Event::DiskBasedDiagnosticsFinished { .. } => {
184                    tx.try_send(()).ok();
185                }
186                _ => {}
187            }
188        }),
189    ];
190
191    cx.spawn(async move |cx| {
192        if !has_lang_server {
193            // some buffers never have a language server, so this aborts quickly in that case.
194            let timeout = cx.background_executor().timer(Duration::from_secs(500));
195            futures::select! {
196                _ = added_rx.next() => {},
197                _ = timeout.fuse() => {
198                    anyhow::bail!("Waiting for language server add timed out after 5 seconds");
199                }
200            };
201        }
202        let timeout = cx.background_executor().timer(Duration::from_secs(60 * 5));
203        let result = futures::select! {
204            _ = rx.next() => {
205                eprintln!("{}⚑ Language server idle", log_prefix);
206                anyhow::Ok(())
207            },
208            _ = timeout.fuse() => {
209                anyhow::bail!("LSP wait timed out after 5 minutes");
210            }
211        };
212        drop(subscriptions);
213        result
214    })
215}