retrieve_context.rs

  1use crate::{
  2    example::{Example, ExampleContext},
  3    headless::EpAppState,
  4    load_project::run_load_project,
  5};
  6use anyhow::Result;
  7use collections::HashSet;
  8use edit_prediction::{DebugEvent, EditPredictionStore};
  9use futures::{FutureExt as _, StreamExt as _, channel::mpsc};
 10use gpui::{AsyncApp, Entity, Task};
 11use language::{Buffer, LanguageNotFound};
 12use project::Project;
 13use std::{sync::Arc, time::Duration};
 14
 15pub async fn run_context_retrieval(
 16    example: &mut Example,
 17    app_state: Arc<EpAppState>,
 18    mut cx: AsyncApp,
 19) {
 20    if example.context.is_some() {
 21        return;
 22    }
 23
 24    run_load_project(example, app_state.clone(), cx.clone()).await;
 25
 26    let state = example.state.as_ref().unwrap();
 27    let project = state.project.clone();
 28
 29    let _lsp_handle = project
 30        .update(&mut cx, |project, cx| {
 31            project.register_buffer_with_language_servers(&state.buffer, cx)
 32        })
 33        .unwrap();
 34
 35    wait_for_language_server_to_start(example, &project, &state.buffer, &mut cx).await;
 36
 37    let ep_store = cx
 38        .update(|cx| EditPredictionStore::try_global(cx).unwrap())
 39        .unwrap();
 40
 41    let mut events = ep_store
 42        .update(&mut cx, |store, cx| {
 43            store.register_buffer(&state.buffer, &project, cx);
 44            store.set_use_context(true);
 45            store.refresh_context(&project, &state.buffer, state.cursor_position, cx);
 46            store.debug_info(&project, cx)
 47        })
 48        .unwrap();
 49
 50    while let Some(event) = events.next().await {
 51        match event {
 52            DebugEvent::ContextRetrievalFinished(_) => {
 53                break;
 54            }
 55            _ => {}
 56        }
 57    }
 58
 59    let context_files = ep_store
 60        .update(&mut cx, |store, cx| store.context_for_project(&project, cx))
 61        .unwrap();
 62
 63    example.context = Some(ExampleContext {
 64        files: context_files,
 65    });
 66}
 67
 68async fn wait_for_language_server_to_start(
 69    example: &Example,
 70    project: &Entity<Project>,
 71    buffer: &Entity<Buffer>,
 72    cx: &mut AsyncApp,
 73) {
 74    let language_registry = project
 75        .read_with(cx, |project, _| project.languages().clone())
 76        .unwrap();
 77    let result = language_registry
 78        .load_language_for_file_path(&example.cursor_path)
 79        .await;
 80
 81    if let Err(error) = result
 82        && !error.is::<LanguageNotFound>()
 83    {
 84        panic!("Failed to load language for file path: {}", error);
 85    }
 86
 87    let Some(language_id) = buffer
 88        .read_with(cx, |buffer, _cx| {
 89            buffer.language().map(|language| language.id())
 90        })
 91        .unwrap()
 92    else {
 93        panic!("No language for {:?}", example.cursor_path);
 94    };
 95
 96    let mut ready_languages = HashSet::default();
 97    let log_prefix = format!("{} | ", example.name);
 98    if !ready_languages.contains(&language_id) {
 99        wait_for_lang_server(&project, &buffer, log_prefix, cx)
100            .await
101            .unwrap();
102        ready_languages.insert(language_id);
103    }
104
105    let lsp_store = project
106        .read_with(cx, |project, _cx| project.lsp_store())
107        .unwrap();
108
109    // hacky wait for buffer to be registered with the language server
110    for _ in 0..100 {
111        if lsp_store
112            .update(cx, |lsp_store, cx| {
113                buffer.update(cx, |buffer, cx| {
114                    lsp_store
115                        .language_servers_for_local_buffer(&buffer, cx)
116                        .next()
117                        .map(|(_, language_server)| language_server.server_id())
118                })
119            })
120            .unwrap()
121            .is_some()
122        {
123            return;
124        } else {
125            cx.background_executor()
126                .timer(Duration::from_millis(10))
127                .await;
128        }
129    }
130
131    panic!("No language server found for buffer");
132}
133
134pub fn wait_for_lang_server(
135    project: &Entity<Project>,
136    buffer: &Entity<Buffer>,
137    log_prefix: String,
138    cx: &mut AsyncApp,
139) -> Task<Result<()>> {
140    eprintln!("{}⏵ Waiting for language server", log_prefix);
141
142    let (mut tx, mut rx) = mpsc::channel(1);
143
144    let lsp_store = project
145        .read_with(cx, |project, _| project.lsp_store())
146        .unwrap();
147
148    let has_lang_server = buffer
149        .update(cx, |buffer, cx| {
150            lsp_store.update(cx, |lsp_store, cx| {
151                lsp_store
152                    .language_servers_for_local_buffer(buffer, cx)
153                    .next()
154                    .is_some()
155            })
156        })
157        .unwrap_or(false);
158
159    if has_lang_server {
160        project
161            .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
162            .unwrap()
163            .detach();
164    }
165    let (mut added_tx, mut added_rx) = mpsc::channel(1);
166
167    let subscriptions = [
168        cx.subscribe(&lsp_store, {
169            let log_prefix = log_prefix.clone();
170            move |_, event, _| {
171                if let project::LspStoreEvent::LanguageServerUpdate {
172                    message:
173                        client::proto::update_language_server::Variant::WorkProgress(
174                            client::proto::LspWorkProgress {
175                                message: Some(message),
176                                ..
177                            },
178                        ),
179                    ..
180                } = event
181                {
182                    eprintln!("{}{message}", log_prefix)
183                }
184            }
185        }),
186        cx.subscribe(project, {
187            let buffer = buffer.clone();
188            move |project, event, cx| match event {
189                project::Event::LanguageServerAdded(_, _, _) => {
190                    let buffer = buffer.clone();
191                    project
192                        .update(cx, |project, cx| project.save_buffer(buffer, cx))
193                        .detach();
194                    added_tx.try_send(()).ok();
195                }
196                project::Event::DiskBasedDiagnosticsFinished { .. } => {
197                    tx.try_send(()).ok();
198                }
199                _ => {}
200            }
201        }),
202    ];
203
204    cx.spawn(async move |cx| {
205        if !has_lang_server {
206            // some buffers never have a language server, so this aborts quickly in that case.
207            let timeout = cx.background_executor().timer(Duration::from_secs(500));
208            futures::select! {
209                _ = added_rx.next() => {},
210                _ = timeout.fuse() => {
211                    anyhow::bail!("Waiting for language server add timed out after 5 seconds");
212                }
213            };
214        }
215        let timeout = cx.background_executor().timer(Duration::from_secs(60 * 5));
216        let result = futures::select! {
217            _ = rx.next() => {
218                eprintln!("{}⚑ Language server idle", log_prefix);
219                anyhow::Ok(())
220            },
221            _ = timeout.fuse() => {
222                anyhow::bail!("LSP wait timed out after 5 minutes");
223            }
224        };
225        drop(subscriptions);
226        result
227    })
228}