1use anyhow::{anyhow, Context, Result};
2use gpui::{executor, AppContext, Task};
3use parking_lot::{Mutex, RwLock};
4use postage::{barrier, oneshot, prelude::Stream, sink::Sink};
5use serde::{Deserialize, Serialize};
6use serde_json::{json, value::RawValue, Value};
7use smol::{
8 channel,
9 io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader},
10 process::Command,
11};
12use std::{
13 collections::HashMap,
14 future::Future,
15 io::Write,
16 str::FromStr,
17 sync::{
18 atomic::{AtomicUsize, Ordering::SeqCst},
19 Arc,
20 },
21};
22use std::{path::Path, process::Stdio};
23use util::TryFutureExt;
24
/// Value of the `jsonrpc` member written on every outgoing message.
const JSON_RPC_VERSION: &str = "2.0";
/// Prefix of the header line that carries each message's length in bytes.
const CONTENT_LEN_HEADER: &str = "Content-Length: ";
27
28type NotificationHandler = Box<dyn Send + Sync + Fn(&str)>;
29type ResponseHandler = Box<dyn Send + FnOnce(Result<&str, Error>)>;
30
/// A spawned language server process plus the JSON-RPC state used to talk to it.
pub struct LanguageServer {
    /// Source of JSON-RPC request ids; `fetch_add` hands out one per request.
    next_id: AtomicUsize,
    /// Serialized messages queued for the output task, which frames them and writes
    /// them to the server's stdin.
    outbound_tx: channel::Sender<Vec<u8>>,
    /// Handlers for incoming notifications, keyed by LSP method name.
    notification_handlers: Arc<RwLock<HashMap<&'static str, NotificationHandler>>>,
    /// Callbacks for in-flight requests, keyed by request id; each is removed and
    /// invoked when the response with that id arrives.
    response_handlers: Arc<Mutex<HashMap<usize, ResponseHandler>>>,
    /// Task that reads and dispatches messages from the server's stdout.
    /// NOTE(review): kept only to tie its lifetime to the server — assumes dropping a
    /// gpui `Task` cancels it; confirm.
    _input_task: Task<Option<()>>,
    /// Task that writes queued outbound messages to the server's stdin (same lifetime note).
    _output_task: Task<Option<()>>,
    /// Released once the `initialize` handshake task finishes; public `request`/`notify`
    /// wait on it before sending.
    initialized: barrier::Receiver,
}
40
41pub struct Subscription {
42 method: &'static str,
43 notification_handlers: Arc<RwLock<HashMap<&'static str, NotificationHandler>>>,
44}
45
/// An outgoing JSON-RPC request. The server's reply echoes `id`, which is how
/// the input task finds the matching entry in `response_handlers`.
#[derive(Serialize)]
struct Request<T> {
    jsonrpc: &'static str,
    id: usize,
    method: &'static str,
    params: T,
}
53
54#[derive(Deserialize)]
55struct Response<'a> {
56 id: usize,
57 #[serde(default)]
58 error: Option<Error>,
59 #[serde(borrow)]
60 result: &'a RawValue,
61}
62
63#[derive(Serialize)]
64struct OutboundNotification<T> {
65 jsonrpc: &'static str,
66 method: &'static str,
67 params: T,
68}
69
70#[derive(Deserialize)]
71struct InboundNotification<'a> {
72 #[serde(borrow)]
73 method: &'a str,
74 #[serde(borrow)]
75 params: &'a RawValue,
76}
77
78#[derive(Debug, Deserialize)]
79struct Error {
80 message: String,
81}
82
83impl LanguageServer {
84 pub fn rust(root_path: &Path, cx: &AppContext) -> Result<Arc<Self>> {
85 const ZED_BUNDLE: Option<&'static str> = option_env!("ZED_BUNDLE");
86 const ZED_TARGET: &'static str = env!("ZED_TARGET");
87
88 let rust_analyzer_name = format!("rust-analyzer-{}", ZED_TARGET);
89 if ZED_BUNDLE.map_or(Ok(false), |b| b.parse())? {
90 let rust_analyzer_path = cx
91 .platform()
92 .path_for_resource(Some(&rust_analyzer_name), None)?;
93 Self::new(root_path, &rust_analyzer_path, cx.background())
94 } else {
95 Self::new(root_path, Path::new(&rust_analyzer_name), cx.background())
96 }
97 }
98
99 pub fn new(
100 root_path: &Path,
101 server_path: &Path,
102 background: &executor::Background,
103 ) -> Result<Arc<Self>> {
104 let mut server = Command::new(server_path)
105 .stdin(Stdio::piped())
106 .stdout(Stdio::piped())
107 .stderr(Stdio::inherit())
108 .spawn()?;
109 let mut stdin = server.stdin.take().unwrap();
110 let mut stdout = BufReader::new(server.stdout.take().unwrap());
111 let (outbound_tx, outbound_rx) = channel::unbounded::<Vec<u8>>();
112 let notification_handlers = Arc::new(RwLock::new(HashMap::<_, NotificationHandler>::new()));
113 let response_handlers = Arc::new(Mutex::new(HashMap::<_, ResponseHandler>::new()));
114 let _input_task = background.spawn(
115 {
116 let notification_handlers = notification_handlers.clone();
117 let response_handlers = response_handlers.clone();
118 async move {
119 let mut buffer = Vec::new();
120 loop {
121 buffer.clear();
122
123 stdout.read_until(b'\n', &mut buffer).await?;
124 stdout.read_until(b'\n', &mut buffer).await?;
125 let message_len: usize = std::str::from_utf8(&buffer)?
126 .strip_prefix(CONTENT_LEN_HEADER)
127 .ok_or_else(|| anyhow!("invalid header"))?
128 .trim_end()
129 .parse()?;
130
131 buffer.resize(message_len, 0);
132 stdout.read_exact(&mut buffer).await?;
133
134 if let Ok(InboundNotification { method, params }) =
135 serde_json::from_slice(&buffer)
136 {
137 if let Some(handler) = notification_handlers.read().get(method) {
138 handler(params.get());
139 } else {
140 log::info!(
141 "unhandled notification {}:\n{}",
142 method,
143 serde_json::to_string_pretty(
144 &Value::from_str(params.get()).unwrap()
145 )
146 .unwrap()
147 );
148 }
149 } else if let Ok(Response { id, error, result }) =
150 serde_json::from_slice(&buffer)
151 {
152 if let Some(handler) = response_handlers.lock().remove(&id) {
153 if let Some(error) = error {
154 handler(Err(error));
155 } else {
156 handler(Ok(result.get()));
157 }
158 }
159 } else {
160 return Err(anyhow!(
161 "failed to deserialize message:\n{}",
162 std::str::from_utf8(&buffer)?
163 ));
164 }
165 }
166 }
167 }
168 .log_err(),
169 );
170 let _output_task = background.spawn(
171 async move {
172 let mut content_len_buffer = Vec::new();
173 loop {
174 content_len_buffer.clear();
175
176 let message = outbound_rx.recv().await?;
177 write!(content_len_buffer, "{}", message.len()).unwrap();
178 stdin.write_all(CONTENT_LEN_HEADER.as_bytes()).await?;
179 stdin.write_all(&content_len_buffer).await?;
180 stdin.write_all("\r\n\r\n".as_bytes()).await?;
181 stdin.write_all(&message).await?;
182 }
183 }
184 .log_err(),
185 );
186
187 let (initialized_tx, initialized_rx) = barrier::channel();
188 let this = Arc::new(Self {
189 notification_handlers,
190 response_handlers,
191 next_id: Default::default(),
192 outbound_tx,
193 _input_task,
194 _output_task,
195 initialized: initialized_rx,
196 });
197
198 let root_uri =
199 lsp_types::Url::from_file_path(root_path).map_err(|_| anyhow!("invalid root path"))?;
200 background
201 .spawn({
202 let this = this.clone();
203 async move {
204 this.init(root_uri).log_err().await;
205 drop(initialized_tx);
206 }
207 })
208 .detach();
209
210 Ok(this)
211 }
212
213 async fn init(self: Arc<Self>, root_uri: lsp_types::Url) -> Result<()> {
214 self.request_internal::<lsp_types::request::Initialize>(lsp_types::InitializeParams {
215 process_id: Default::default(),
216 root_path: Default::default(),
217 root_uri: Some(root_uri),
218 initialization_options: Default::default(),
219 capabilities: lsp_types::ClientCapabilities {
220 experimental: Some(json!({
221 "serverStatusNotification": true,
222 })),
223 ..Default::default()
224 },
225 trace: Default::default(),
226 workspace_folders: Default::default(),
227 client_info: Default::default(),
228 locale: Default::default(),
229 })
230 .await?;
231 self.notify_internal::<lsp_types::notification::Initialized>(
232 lsp_types::InitializedParams {},
233 )
234 .await?;
235 Ok(())
236 }
237
238 pub fn on_notification<T, F>(&self, f: F) -> Subscription
239 where
240 T: lsp_types::notification::Notification,
241 F: 'static + Send + Sync + Fn(T::Params),
242 {
243 let prev_handler = self.notification_handlers.write().insert(
244 T::METHOD,
245 Box::new(
246 move |notification| match serde_json::from_str(notification) {
247 Ok(notification) => f(notification),
248 Err(err) => log::error!("error parsing notification {}: {}", T::METHOD, err),
249 },
250 ),
251 );
252
253 assert!(
254 prev_handler.is_none(),
255 "registered multiple handlers for the same notification"
256 );
257
258 Subscription {
259 method: T::METHOD,
260 notification_handlers: self.notification_handlers.clone(),
261 }
262 }
263
264 pub fn request<T: lsp_types::request::Request>(
265 self: Arc<Self>,
266 params: T::Params,
267 ) -> impl Future<Output = Result<T::Result>>
268 where
269 T::Result: 'static + Send,
270 {
271 let this = self.clone();
272 async move {
273 this.initialized.clone().recv().await;
274 this.request_internal::<T>(params).await
275 }
276 }
277
278 fn request_internal<T: lsp_types::request::Request>(
279 self: &Arc<Self>,
280 params: T::Params,
281 ) -> impl Future<Output = Result<T::Result>>
282 where
283 T::Result: 'static + Send,
284 {
285 let id = self.next_id.fetch_add(1, SeqCst);
286 let message = serde_json::to_vec(&Request {
287 jsonrpc: JSON_RPC_VERSION,
288 id,
289 method: T::METHOD,
290 params,
291 })
292 .unwrap();
293 let mut response_handlers = self.response_handlers.lock();
294 let (mut tx, mut rx) = oneshot::channel();
295 response_handlers.insert(
296 id,
297 Box::new(move |result| {
298 let response = match result {
299 Ok(response) => {
300 serde_json::from_str(response).context("failed to deserialize response")
301 }
302 Err(error) => Err(anyhow!("{}", error.message)),
303 };
304 let _ = tx.try_send(response);
305 }),
306 );
307
308 let this = self.clone();
309 async move {
310 this.outbound_tx.send(message).await?;
311 rx.recv().await.unwrap()
312 }
313 }
314
315 pub fn notify<T: lsp_types::notification::Notification>(
316 self: &Arc<Self>,
317 params: T::Params,
318 ) -> impl Future<Output = Result<()>> {
319 let this = self.clone();
320 async move {
321 this.initialized.clone().recv().await;
322 this.notify_internal::<T>(params).await
323 }
324 }
325
326 fn notify_internal<T: lsp_types::notification::Notification>(
327 self: &Arc<Self>,
328 params: T::Params,
329 ) -> impl Future<Output = Result<()>> {
330 let message = serde_json::to_vec(&OutboundNotification {
331 jsonrpc: JSON_RPC_VERSION,
332 method: T::METHOD,
333 params,
334 })
335 .unwrap();
336
337 let this = self.clone();
338 async move {
339 this.outbound_tx.send(message).await?;
340 Ok(())
341 }
342 }
343}
344
345impl Drop for Subscription {
346 fn drop(&mut self) {
347 self.notification_handlers.write().remove(self.method);
348 }
349}
350
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::TestAppContext;
    use unindent::Unindent;
    use util::test::temp_tree;

    /// rust-analyzer's `experimental/serverStatus` notification. It is only sent because
    /// `init` advertises the `serverStatusNotification` experimental capability.
    pub enum ServerStatusNotification {}

    impl lsp_types::notification::Notification for ServerStatusNotification {
        type Params = ServerStatusParams;
        const METHOD: &'static str = "experimental/serverStatus";
    }

    /// Payload of `experimental/serverStatus`; only the `quiescent` flag is read.
    #[derive(Deserialize, Serialize, PartialEq, Eq, Clone)]
    pub struct ServerStatusParams {
        pub quiescent: bool,
    }

    impl LanguageServer {
        /// Resolves once the server sends a `serverStatus` notification with
        /// `quiescent: true`.
        async fn next_idle_notification(self: &Arc<Self>) {
            let (idle_tx, idle_rx) = channel::unbounded();
            let _subscription =
                self.on_notification::<ServerStatusNotification, _>(move |status| {
                    if status.quiescent {
                        idle_tx.try_send(()).unwrap();
                    }
                });
            let _ = idle_rx.recv().await;
        }
    }

    /// End-to-end check against a real rust-analyzer: open a file, then hover a
    /// string literal.
    #[gpui::test]
    async fn test_basic(cx: TestAppContext) {
        let source_text = r#"
            fn fun() {
                let hello = "world";
            }
        "#
        .unindent();
        let project_dir = temp_tree(json!({
            "Cargo.toml": r#"
                [package]
                name = "temp"
                version = "0.1.0"
                edition = "2018"
            "#.unindent(),
            "src": {
                "lib.rs": &source_text
            }
        }));
        let file_uri =
            lsp_types::Url::from_file_path(project_dir.path().join("src/lib.rs")).unwrap();

        let server = cx.read(|cx| LanguageServer::rust(project_dir.path(), cx).unwrap());
        server.next_idle_notification().await;

        let open_params = lsp_types::DidOpenTextDocumentParams {
            text_document: lsp_types::TextDocumentItem::new(
                file_uri.clone(),
                "rust".to_string(),
                0,
                source_text,
            ),
        };
        server
            .notify::<lsp_types::notification::DidOpenTextDocument>(open_params)
            .await
            .unwrap();

        let hover_params = lsp_types::HoverParams {
            text_document_position_params: lsp_types::TextDocumentPositionParams {
                text_document: lsp_types::TextDocumentIdentifier::new(file_uri),
                position: lsp_types::Position::new(1, 21),
            },
            work_done_progress_params: Default::default(),
        };
        let hover_result = server
            .request::<lsp_types::request::HoverRequest>(hover_params)
            .await
            .unwrap()
            .unwrap();

        let expected_contents = lsp_types::HoverContents::Markup(lsp_types::MarkupContent {
            kind: lsp_types::MarkupKind::Markdown,
            value: "&str".to_string(),
        });
        assert_eq!(hover_result.contents, expected_contents);
    }
}