1use crate::{
2 example::{Example, ExampleContext},
3 headless::EpAppState,
4 load_project::run_load_project,
5};
6use anyhow::Result;
7use collections::HashSet;
8use edit_prediction::{DebugEvent, EditPredictionStore};
9use futures::{FutureExt as _, StreamExt as _, channel::mpsc};
10use gpui::{AsyncApp, Entity, Task};
11use language::Buffer;
12use project::Project;
13use std::{sync::Arc, time::Duration};
14
15pub async fn run_context_retrieval(
16 example: &mut Example,
17 app_state: Arc<EpAppState>,
18 mut cx: AsyncApp,
19) {
20 if example.context.is_some() {
21 return;
22 }
23
24 run_load_project(example, app_state.clone(), cx.clone()).await;
25
26 let state = example.state.as_ref().unwrap();
27 let project = state.project.clone();
28
29 let _lsp_handle = project
30 .update(&mut cx, |project, cx| {
31 project.register_buffer_with_language_servers(&state.buffer, cx)
32 })
33 .unwrap();
34
35 wait_for_language_server_to_start(example, &project, &state.buffer, &mut cx).await;
36
37 let ep_store = cx
38 .update(|cx| EditPredictionStore::try_global(cx).unwrap())
39 .unwrap();
40
41 let mut events = ep_store
42 .update(&mut cx, |store, cx| {
43 store.register_buffer(&state.buffer, &project, cx);
44 store.set_use_context(true);
45 store.refresh_context(&project, &state.buffer, state.cursor_position, cx);
46 store.debug_info(&project, cx)
47 })
48 .unwrap();
49
50 while let Some(event) = events.next().await {
51 match event {
52 DebugEvent::ContextRetrievalFinished(_) => {
53 break;
54 }
55 _ => {}
56 }
57 }
58
59 let context_files = ep_store
60 .update(&mut cx, |store, cx| store.context_for_project(&project, cx))
61 .unwrap();
62
63 example.context = Some(ExampleContext {
64 files: context_files,
65 });
66}
67
68async fn wait_for_language_server_to_start(
69 example: &Example,
70 project: &Entity<Project>,
71 buffer: &Entity<Buffer>,
72 cx: &mut AsyncApp,
73) {
74 let Some(language_id) = buffer
75 .read_with(cx, |buffer, _cx| {
76 buffer.language().map(|language| language.id())
77 })
78 .unwrap()
79 else {
80 panic!("No language for {:?}", example.cursor_path);
81 };
82
83 let mut ready_languages = HashSet::default();
84 let log_prefix = format!("{} | ", example.name);
85 if !ready_languages.contains(&language_id) {
86 wait_for_lang_server(&project, &buffer, log_prefix, cx)
87 .await
88 .unwrap();
89 ready_languages.insert(language_id);
90 }
91
92 let lsp_store = project
93 .read_with(cx, |project, _cx| project.lsp_store())
94 .unwrap();
95
96 // hacky wait for buffer to be registered with the language server
97 for _ in 0..100 {
98 if lsp_store
99 .update(cx, |lsp_store, cx| {
100 buffer.update(cx, |buffer, cx| {
101 lsp_store
102 .language_servers_for_local_buffer(&buffer, cx)
103 .next()
104 .map(|(_, language_server)| language_server.server_id())
105 })
106 })
107 .unwrap()
108 .is_some()
109 {
110 return;
111 } else {
112 cx.background_executor()
113 .timer(Duration::from_millis(10))
114 .await;
115 }
116 }
117
118 panic!("No language server found for buffer");
119}
120
121pub fn wait_for_lang_server(
122 project: &Entity<Project>,
123 buffer: &Entity<Buffer>,
124 log_prefix: String,
125 cx: &mut AsyncApp,
126) -> Task<Result<()>> {
127 eprintln!("{}⏵ Waiting for language server", log_prefix);
128
129 let (mut tx, mut rx) = mpsc::channel(1);
130
131 let lsp_store = project
132 .read_with(cx, |project, _| project.lsp_store())
133 .unwrap();
134
135 let has_lang_server = buffer
136 .update(cx, |buffer, cx| {
137 lsp_store.update(cx, |lsp_store, cx| {
138 lsp_store
139 .language_servers_for_local_buffer(buffer, cx)
140 .next()
141 .is_some()
142 })
143 })
144 .unwrap_or(false);
145
146 if has_lang_server {
147 project
148 .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
149 .unwrap()
150 .detach();
151 }
152 let (mut added_tx, mut added_rx) = mpsc::channel(1);
153
154 let subscriptions = [
155 cx.subscribe(&lsp_store, {
156 let log_prefix = log_prefix.clone();
157 move |_, event, _| {
158 if let project::LspStoreEvent::LanguageServerUpdate {
159 message:
160 client::proto::update_language_server::Variant::WorkProgress(
161 client::proto::LspWorkProgress {
162 message: Some(message),
163 ..
164 },
165 ),
166 ..
167 } = event
168 {
169 eprintln!("{}⟲ {message}", log_prefix)
170 }
171 }
172 }),
173 cx.subscribe(project, {
174 let buffer = buffer.clone();
175 move |project, event, cx| match event {
176 project::Event::LanguageServerAdded(_, _, _) => {
177 let buffer = buffer.clone();
178 project
179 .update(cx, |project, cx| project.save_buffer(buffer, cx))
180 .detach();
181 added_tx.try_send(()).ok();
182 }
183 project::Event::DiskBasedDiagnosticsFinished { .. } => {
184 tx.try_send(()).ok();
185 }
186 _ => {}
187 }
188 }),
189 ];
190
191 cx.spawn(async move |cx| {
192 if !has_lang_server {
193 // some buffers never have a language server, so this aborts quickly in that case.
194 let timeout = cx.background_executor().timer(Duration::from_secs(500));
195 futures::select! {
196 _ = added_rx.next() => {},
197 _ = timeout.fuse() => {
198 anyhow::bail!("Waiting for language server add timed out after 5 seconds");
199 }
200 };
201 }
202 let timeout = cx.background_executor().timer(Duration::from_secs(60 * 5));
203 let result = futures::select! {
204 _ = rx.next() => {
205 eprintln!("{}⚑ Language server idle", log_prefix);
206 anyhow::Ok(())
207 },
208 _ = timeout.fuse() => {
209 anyhow::bail!("LSP wait timed out after 5 minutes");
210 }
211 };
212 drop(subscriptions);
213 result
214 })
215}